# Let's build the transformer's encoder

## 0. Init

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
# Hyperparameters Here
BATCH_SIZE = 128
T = 80                                  # sentence length
D_K = 16
D_V = 16
D_MODEL = 128
H = 8
VOCAB_SIZE = 10
N_MHA_BLOCKS_ENCODER = 6
N_CLASSES = 2                          # classes of classifier

# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

## 1. Attention

In [3]:
# Test softmax x axis:
matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

softmaxed_matrix = F.softmax(matrix, dim=1)

print(softmaxed_matrix)

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])


In [4]:
def repeat(x: torch.Tensor, n: int):
    # make shape (n, 1, 1, ...) --> quantity of 1's must be len(x.shape)
    # for example, if shape of x is (3, 4, 8), shapee must be (n, 1, 1, 1)
    tuple_ones = tuple(
        (torch.tensor(x.shape) / torch.tensor(x.shape)).numpy().astype(int)
    )
    # print((n, *tuple_ones))
    return x.unsqueeze(0).repeat((n, *tuple_ones))


def batched_matmul(x_batched, W):
    # # Assuming x_batched.shape == (batch_size, T, d_model)
    # # and W.shape == (d_model, d)
    # batch_size, T, d_model = x_batched.shape
    # d = W.shape[1]

    # # Reshape x_batched to (batch_size * T, d_model)
    # x_reshaped = x_batched.reshape(-1, d_model)

    # # Perform matrix multiplication
    # result = torch.matmul(x_reshaped, W)

    # # Reshape the result back to (batch_size, T, d)
    # result = result.reshape(batch_size, T, d)

    # return result

    # batch_size = x_batched.shape[0]
    # W_repeated = W.unsqueeze(0).repeat((batch_size, 1, 1))
    W_repeated = repeat(W, n=x_batched.shape[0])
    
    return torch.bmm(x_batched, W_repeated)

In [5]:
class Attention(nn.Module):
    """ Convention from: https://www.udemy.com/course/data-science-transformers-nlp/learn/lecture/32255056#overview
    In our convention, K, Q and V are learneable, different from the "Attention is all you need" paper.
    """
    def __init__(self, T: int, d_K, d_V, d_model: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # define a torch 2d tensor initialized normally
        self.W_K = torch.normal(mean=0, std=0.01, size=(d_model, d_K), requires_grad=True)
        self.W_Q = torch.normal(mean=0, std=0.01, size=(d_model, d_K), requires_grad=True)
        self.W_V = torch.normal(mean=0, std=0.01, size=(d_model, d_V), requires_grad=True)
        self.mask = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shapes:
        # W_K (d_model, d_K)
        # x is a 3d tensor (batch x T x d_model)

        # W_K.T ->  (1, d_k x d_model)
        # x ->      (batch, T, d_model)
        K = batched_matmul(x, self.W_K)
        Q = batched_matmul(x, self.W_Q)
        V = batched_matmul(x, self.W_V)

        # (batch, T, d_model) x (batch, d_model, d_k) -> (batch, T, d_k)
        result = torch.bmm(Q, K.transpose(1, 2)) / (K.shape[-1] ** 0.5)
        if self.mask:
            result = batched_matmul(result, self.mask)
        result = F.softmax(result, dim=-1)
        result = torch.bmm(result, V)
        return result

In [6]:
att = Attention(T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL)
x = torch.normal(mean=0, std=0.01, size=(BATCH_SIZE, T, D_MODEL))

att_result = att.forward(x)
assert att_result.shape == (BATCH_SIZE, T, D_V)
print(att.W_K.shape)
print(att.W_Q.shape)
print(att.W_V.shape)
print(att_result.shape)

torch.Size([128, 16])
torch.Size([128, 16])
torch.Size([128, 16])
torch.Size([128, 80, 16])


## 2. Multi-Head Attention

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self, h: int, T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.h = h
        self.attentions = nn.ModuleList(
            [Attention(T=T, d_K=d_K, d_model=d_model, d_V=d_V) for _ in range(h)]
        )
        self.W_O = torch.normal(0, 0.1, size=(h * d_V, d_model), requires_grad=True)
        self.T = T
        self.d_V = d_V

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attention_results = []
        
        for attention in self.attentions:
            attention_result = attention(x)
            attention_results.append(attention_result)

        concatenated = torch.concat(attention_results, dim=-1)
        return batched_matmul(concatenated, self.W_O)

In [8]:
mha = MultiHeadAttention(h=H, T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL)
assert mha(x).shape == (BATCH_SIZE, T, D_MODEL)

## 3. The transformer block

In [9]:
class TransformerBlock(nn.Module):
    def __init__(
        self, T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL, h=H, dropout=0.1, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.mha = MultiHeadAttention(h, T=T, d_K=d_K, d_V=d_V, d_model=d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.ann = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Softmax(dim=-1),
        )
    
    def forward(self, x: torch.Tensor):
        x = self.layer_norm(x + self.mha(x))
        x = self.layer_norm(x + self.ann(x))
        return x
        
transformerBlock = TransformerBlock()
transformerBlock(x).shape

torch.Size([128, 80, 128])

## 4. The positional encoding

In [10]:
def PositionalEncoding(T: int, d_model) -> torch.Tensor:
    encodings = torch.zeros(size=(T, d_model), requires_grad=False)
    counter = 0
    for pos in range(T):
        for i in range((d_model // 2) + 1):
            if 2 * i < d_model:
                counter += 1
                encodings[pos, 2 * i] = torch.sin(
                    pos / torch.tensor(10000).pow(2 * i / d_model)
                )
            if 2 * i + 1 < d_model:
                counter += 1
                encodings[pos, 2 * i + 1] = torch.cos(
                    pos / torch.tensor(10000).pow(2 * i / d_model)
                )
    assert counter == T * d_model
    return encodings


PositionalEncoding(T, D_MODEL).shape

torch.Size([80, 128])

In [11]:
torch.range(0, 10).reshape(-1, 1)

  torch.range(0, 10).reshape(-1, 1)


tensor([[ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 4.],
        [ 5.],
        [ 6.],
        [ 7.],
        [ 8.],
        [ 9.],
        [10.]])

In [12]:
# class Embedding(nn.Module):
#     def __init__(self, vocab_size: int, d_model: int, *args, **kwargs) -> None:
#         super().__init__(*args, **kwargs)
#         self.vocab_size = vocab_size
#         self.embedding = torch.normal(
#             mean=0.0, std=0.1, size=(vocab_size, d_model), requires_grad=True
#         )

#     # TODO: make work in batches
#     def forward(self, x_one_hot: torch.Tensor):
#         # print(x_one_hot.shape)
#         batched_range = torch.arange(self.vocab_size).type(torch.float32)
#         batched_range = batched_range.unsqueeze(0).repeat(x_one_hot.shape[1], 1)
#         # print(batched_range.shape)
#         positions = batched_matmul(x_one_hot, batched_range.transpose(1, 0)).type(torch.int64)

#         print(positions)
#         print(self.embedding.shape)
#         # print(self.embedding)
#         return self.embedding[positions]


# emb = Embedding(3, 2)
# emb.forward(
#     torch.FloatTensor(
#         [
#             [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]],
#             [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]],
#         ]
#     )
# )

## 5. The embedding layer

In [13]:
# This Embedding class was designed for one hot encoded inputs
# But this won't be the case. Inputs will be integers.
# So we won't use this
# 
# 
# class Embedding(nn.Module):
#     def __init__(self, vocab_size: int, d_model: int, *args, **kwargs) -> None:
#         super().__init__(*args, **kwargs)
#         self.vocab_size = vocab_size
#         self.embedding = torch.normal(
#             mean=0.0, std=0.1, size=(vocab_size, d_model), requires_grad=True
#         )

#     def forward(self, x_one_hot: torch.Tensor):
#         positions = torch.matmul(
#             x_one_hot, torch.arange(self.vocab_size, dtype=torch.float32)
#         ).type(torch.int64)

#         # print(positions)
#         # print(self.embedding)
#         return self.embedding[positions]


# emb = Embedding(vocab_size=3, d_model=2)
# emb.forward(
#     torch.FloatTensor(
#         [
#             [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]],
#             [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]],
#         ]
#     )
# )

In [14]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model, padding_idx=0):
        super(Embedding, self).__init__()
        self.embedding = torch.normal(
            mean=0.0, std=0.1, size=(vocab_size, d_model), requires_grad=True
        )

    def forward(self, x):
        return self.embedding[x.int()]


emb = Embedding(vocab_size=3, d_model=2)
emb.forward(torch.FloatTensor([[2, 1], [1, 0]]))

tensor([[[-0.1451, -0.2171],
         [-0.0268,  0.2124]],

        [[-0.0268,  0.2124],
         [-0.0004,  0.1490]]], grad_fn=<IndexBackward0>)

## 6. The Classification Encoder

In [15]:
class ClassifierEncoder(nn.Module):
    def __init__(
        self,
        T=T,
        d_K=D_K,
        d_V=D_V,
        d_model=D_MODEL,
        h=H,
        vocab_size=VOCAB_SIZE,
        n_classes=N_CLASSES,
        dropout=0.1,
        *args,
        **kwargs
    ):
        super(ClassifierEncoder, self).__init__()
        self.T = T
        self.d_K = d_K

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding: torch.Tensor = PositionalEncoding(T, d_model)

        self.transformer_blocks = nn.ModuleList()
        for _ in range(N_MHA_BLOCKS_ENCODER):
            self.transformer_blocks.append(
                TransformerBlock(T=T, d_K=d_K, d_V=d_V, d_model=d_model, h=h)
            )

        self.prediction_head = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        positionalEncoding = self.position_encoding.repeat(x.shape[0], 1, 1)

        x = positionalEncoding + self.embedding(x)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.prediction_head(x)

        x = x[:, -1, :]

        x = F.softmax(x, dim=-1)
        # Select the last value along the T dimension
        # Because classification is only obtained at the end of the sequence

        return x

# Let's try a classification problem using the CLassification Encoder

* See the file: Fine-Tuning (Intermediate)/Fine-Tunning Sentiment Custom Dataset + Labels.ipynb

## First, lets make a dataset with cross validation

In [16]:
import numpy as np
import pandas as pd

In [17]:
df_ = pd.read_csv("../Fine-Tuning (Intermediate)/AirlineTweets.csv")
df_.head()

Unnamed: 0,tweet_id,airline_sentiment,airline_sentiment_confidence,negativereason,negativereason_confidence,airline,airline_sentiment_gold,name,negativereason_gold,retweet_count,text,tweet_coord,tweet_created,tweet_location,user_timezone
0,570306133677760513,neutral,1.0,,,Virgin America,,cairdin,,0,@VirginAmerica What @dhepburn said.,,2015-02-24 11:35:52 -0800,,Eastern Time (US & Canada)
1,570301130888122368,positive,0.3486,,0.0,Virgin America,,jnardino,,0,@VirginAmerica plus you've added commercials t...,,2015-02-24 11:15:59 -0800,,Pacific Time (US & Canada)
2,570301083672813571,neutral,0.6837,,,Virgin America,,yvonnalynn,,0,@VirginAmerica I didn't today... Must mean I n...,,2015-02-24 11:15:48 -0800,Lets Play,Central Time (US & Canada)
3,570301031407624196,negative,1.0,Bad Flight,0.7033,Virgin America,,jnardino,,0,@VirginAmerica it's really aggressive to blast...,,2015-02-24 11:15:36 -0800,,Pacific Time (US & Canada)
4,570300817074462722,negative,1.0,Can't Tell,1.0,Virgin America,,jnardino,,0,@VirginAmerica and it's a really big bad thing...,,2015-02-24 11:14:45 -0800,,Pacific Time (US & Canada)


In [18]:
df = df_[['airline_sentiment', 'text']].copy()
df.head()

Unnamed: 0,airline_sentiment,text
0,neutral,@VirginAmerica What @dhepburn said.
1,positive,@VirginAmerica plus you've added commercials t...
2,neutral,@VirginAmerica I didn't today... Must mean I n...
3,negative,@VirginAmerica it's really aggressive to blast...
4,negative,@VirginAmerica and it's a really big bad thing...


In [19]:
target_map = {
    'positive': 1,
    'negative': 0,
    'neutral': 2,
}
df['target'] = df['airline_sentiment'].map(target_map)
df.head()

Unnamed: 0,airline_sentiment,text,target
0,neutral,@VirginAmerica What @dhepburn said.,2
1,positive,@VirginAmerica plus you've added commercials t...,1
2,neutral,@VirginAmerica I didn't today... Must mean I n...,2
3,negative,@VirginAmerica it's really aggressive to blast...,0
4,negative,@VirginAmerica and it's a really big bad thing...,0


In [20]:
df_filtered = df[df['target'] != 2]
df_filtered.head()

Unnamed: 0,airline_sentiment,text,target
1,positive,@VirginAmerica plus you've added commercials t...,1
3,negative,@VirginAmerica it's really aggressive to blast...,0
4,negative,@VirginAmerica and it's a really big bad thing...,0
5,negative,@VirginAmerica seriously would pay $30 a fligh...,0
6,positive,"@VirginAmerica yes, nearly every time I fly VX...",1


In [21]:
df2 = df_filtered[['text', 'target']]
# Not documented info: targets must have the column name label
# sentence may have other names, but not label
df2.columns = ['sentence', 'label']
df2.to_csv("data.csv", index=False)
!head data.csv
# df2.head()

sentence,label
@VirginAmerica plus you've added commercials to the experience... tacky.,1
"@VirginAmerica it's really aggressive to blast obnoxious ""entertainment"" in your guests' faces &amp; they have little recourse",0
@VirginAmerica and it's a really big bad thing about it,0
"@VirginAmerica seriously would pay $30 a flight for seats that didn't have this playing.
it's really the only bad thing about flying VA",0
"@VirginAmerica yes, nearly every time I fly VX this “ear worm” won’t go away :)",1
"@virginamerica Well, I didn't…but NOW I DO! :-D",1
"@VirginAmerica it was amazing, and arrived an hour early. You're too good to me.",1
@VirginAmerica I &lt;3 pretty graphics. so much better than minimal iconography. :D,1


In [22]:
from datasets import load_dataset
raw_dataset = load_dataset('csv', data_files="data.csv")

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [23]:
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 11541
    })
})

In [24]:
from datasets import Dataset, DatasetDict

def splitTrainTestValidation(dataset: Dataset, valid_size=.1, test_size=.1):
    len_valid = int(len(dataset) * valid_size)
    len_test = int(len(dataset) * test_size)
    
    splited: DatasetDict = dataset.train_test_split(len_valid + len_test, shuffle=False, seed=42)
    splited['validation'] = splited['test']
    del splited['test']
    
    splited_2 = splited['validation'].train_test_split(len_test, shuffle=True, seed=42)
    splited['validation'] = splited_2['train']
    splited['test'] = splited_2['test']
    
    return splited

In [25]:
# split = raw_dataset['train'].train_test_split(test_size=.3, seed=42)
split = splitTrainTestValidation(raw_dataset['train'], valid_size=.1, test_size=.1)

In [26]:
split

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 9233
    })
    validation: Dataset({
        features: ['sentence', 'label'],
        num_rows: 1154
    })
    test: Dataset({
        features: ['sentence', 'label'],
        num_rows: 1154
    })
})

In [27]:
from transformers import AutoTokenizer

checkpoint = "distilbert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_fn(batch):
    return tokenizer(batch["sentence"], truncation=True, padding='max_length', max_length=T)


# tokenizer("This is an example", truncation=True, padding=True)

In [28]:

tokenized_datasets = split.map(tokenize_fn, batched=False)
tokenized_datasets

Map:   0%|          | 0/9233 [00:00<?, ? examples/s]

Map:   0%|          | 0/1154 [00:00<?, ? examples/s]

Map:   0%|          | 0/1154 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask'],
        num_rows: 9233
    })
    validation: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask'],
        num_rows: 1154
    })
    test: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask'],
        num_rows: 1154
    })
})

In [29]:
# Little notation here:
# token is an int from the tokenizer
# idx is our index, to use in our embedding

token2idx = {0: 0}
idx2token = {}

all_tokens = [
    element
    for list_ids in tokenized_datasets["train"]["input_ids"]
    for element in list_ids
]
all_tokens = list(set(all_tokens))

token_index = 0
for token in all_tokens:
    if token not in token2idx:
        token2idx[token] = token_index
        idx2token[token_index] = token
        token_index += 1


def filterSplit(splited_dataset):
    """For valid and test datasets, get only those which all inpu_ids is in splited_dataset['train']"""

    for split in ["validation", "test"]:
        # Filter the splited_dataset[split] to only keep the ids which are in splited_dataset['train']
        splited_dataset[split] = splited_dataset[split].filter(
            lambda x: all(token in token2idx for token in x["input_ids"])
        )

    return splited_dataset


filtered_datasets = filterSplit(tokenized_datasets)
filtered_datasets

Filter:   0%|          | 0/1154 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1154 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask'],
        num_rows: 9233
    })
    validation: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask'],
        num_rows: 853
    })
    test: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask'],
        num_rows: 844
    })
})

Vamos construir nosso dicionário de tokens

In [30]:
def makeIndex(batch):
    batch["input_idx"] = [token2idx[single] for single in batch["input_ids"]]
    return batch

data = filtered_datasets.map(makeIndex, batched=False)
data

Map:   0%|          | 0/9233 [00:00<?, ? examples/s]

Map:   0%|          | 0/853 [00:00<?, ? examples/s]

Map:   0%|          | 0/844 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask', 'input_idx'],
        num_rows: 9233
    })
    validation: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask', 'input_idx'],
        num_rows: 853
    })
    test: Dataset({
        features: ['sentence', 'label', 'input_ids', 'attention_mask', 'input_idx'],
        num_rows: 844
    })
})

In [31]:
# Assert all same length
list(set([len(l_idx) for l_idx in data['train']['input_idx']]))

[80]

In [32]:
data_train = torch.Tensor(data['train']['input_idx']).type(torch.int32)
data_train.shape

torch.Size([9233, 80])

In [33]:
import torch.nn.functional as F
target_train = torch.Tensor(data['train']['label'])

#make one_hot
target_train = F.one_hot(target_train.to(torch.int64), num_classes=N_CLASSES).float()

target_train[:10]

tensor([[0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.]])

In [34]:
data_valid = torch.Tensor(data['validation']['input_idx']).type(torch.int32)
data_valid.shape

torch.Size([853, 80])

In [35]:
target_valid = torch.Tensor(data['validation']['label'])

#make one_hot
target_valid = F.one_hot(target_valid.to(torch.int64), num_classes=N_CLASSES).float()

target_valid[:10]

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.]])

In [36]:
data_test = torch.Tensor(data['test']['input_idx']).type(torch.int32)
data_test.shape

torch.Size([844, 80])

In [37]:
target_test = torch.Tensor(data['test']['label'])

#make one_hot
target_test = F.one_hot(target_test.to(torch.int64), num_classes=N_CLASSES).float()

target_test[:10]

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.]])

## Now let's train our model! 

In [41]:
import torch
log_interval = 10
EPOCHS = 10

def train(model, train_loader, optimizer, criterion, epoch):
    model = model.to(device)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def evaluate(model, test_loader, criterion):
    model = model.to(device)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum()
            
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100))

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

model = ClassifierEncoder(vocab_size=len(token2idx))
train_dataset = MyDataset(data_train, target_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


test_dataset = MyDataset(data_train, target_train)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

for epoch in range(1, EPOCHS + 1):
    print("Epoch: ", epoch)
    train(model, train_loader, optimizer, criterion, epoch)
    evaluate(model, test_loader, criterion)



Epoch:  1


RuntimeError: shape '[128, 1]' is invalid for input of size 256