In [17]:
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import math
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import sklearn

In [2]:
DEVICE = "cpu"

In [3]:
csv_file = 'dataset.csv'

# read only 1000 reviews for speed
df = pd.read_csv(csv_file, nrows=1_000, index_col=0)

df

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,...,x1514,x1515,x1516,x1517,x1518,x1519,x1520,x1521,y0,y1
0,17664,3826,1374,2189,13895,19035,18361,10508,12642,74,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
1,17664,4823,13840,15853,1657,8712,16992,16992,17664,2189,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
2,17664,2154,477,1401,12896,4823,13840,13421,6563,17657,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
3,17664,5958,1526,1851,4823,7663,3945,4823,15853,7046,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1
4,17664,3948,8652,1851,6059,3661,15581,2189,3854,1374,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,17664,3928,10193,8775,8712,17664,6516,14213,12870,2753,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
996,17664,2154,721,15595,8712,17664,2154,4865,5069,16680,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1
997,17664,2154,6645,3976,6563,12852,4348,18264,16960,12318,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1
998,17664,11590,11443,8581,4790,10193,14059,6563,14699,1401,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1


In [4]:
vocabulary_size = df.max().max() + 1
vocabulary_size

19504

In [5]:
# train - 20 positive, 20 negative
# test - 5 positive, 5 negative

positive_mask = df.y0 == 1
negative_mask = df.y0 == 0

positive_reviews = df[positive_mask]
negative_reviews = df[negative_mask]

train = pd.concat([positive_reviews.iloc[:20,:], negative_reviews.iloc[:20,:]])
test = pd.concat([positive_reviews.iloc[20:25,:], negative_reviews.iloc[20:25,:]])
train

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,...,x1514,x1515,x1516,x1517,x1518,x1519,x1520,x1521,y0,y1
0,17664,3826,1374,2189,13895,19035,18361,10508,12642,74,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
1,17664,4823,13840,15853,1657,8712,16992,16992,17664,2189,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
2,17664,2154,477,1401,12896,4823,13840,13421,6563,17657,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
4,17664,3948,8652,1851,6059,3661,15581,2189,3854,1374,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
5,17664,19205,15039,9249,18059,2281,2368,4823,7946,1374,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
6,17664,2154,8669,7521,8581,6563,14699,4823,5531,1374,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
9,17664,11590,11443,8581,1696,4145,5294,16537,11443,11141,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
14,17664,1401,4823,8392,2281,1374,17252,12814,849,2534,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
16,17664,19088,9970,6516,2252,5127,10526,12852,4617,8712,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
18,17664,2154,11209,1401,10546,2368,15595,12896,2189,18124,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0


In [6]:
test

Unnamed: 0,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,...,x1514,x1515,x1516,x1517,x1518,x1519,x1520,x1521,y0,y1
44,17664,1401,2281,1552,9697,11750,4790,8712,17664,9388,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
45,17664,8351,4823,5907,2368,2154,18816,1078,2189,2281,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
48,17664,18047,9794,4622,2189,8920,18264,2189,7149,12896,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
50,17664,5214,6563,2189,865,948,10193,3826,1374,17295,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
51,17664,13077,13077,13077,4266,13077,13077,13077,12571,6537,...,2738,2738,2738,2738,2738,2738,2738,2738,1,0
39,17664,74,711,14641,1401,8434,1374,3001,2368,15039,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1
40,17664,15595,11529,12571,2189,11828,1374,15707,1374,1401,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1
42,17664,1374,12571,2189,9970,2154,18785,1078,2368,1401,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1
43,17664,2154,11529,9038,7830,12518,15207,6059,5240,1374,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1
46,17664,16433,10193,6732,14279,2281,349,13390,15964,9158,...,2738,2738,2738,2738,2738,2738,2738,2738,0,1


In [7]:
train_x, train_y = train.iloc[:,:-2], train.iloc[:,-2:]
test_x, test_y = test.iloc[:,:-2], test.iloc[:,-2:]
test_y

Unnamed: 0,y0,y1
44,1,0
45,1,0
48,1,0
50,1,0
51,1,0
39,0,1
40,0,1
42,0,1
43,0,1
46,0,1


In [8]:
class TransformerModule(nn.Module):
    def __init__(self, vocabulary_size, max_length, d_model):
        super().__init__()
        self.max_length = max_length
        self.d_model = d_model

        # apparently nn.Embedding should take vocabulary size as first argument
        # if we pass a number greater than the first argument we get an IndexError
        # in our dataset, the number of words greatly exceeds the length of the max sentence
        # so max length is unfit to be first argument here
        self.embedding = nn.Embedding(vocabulary_size, d_model)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

        global DEVICE
        self.device = 'cpu'
        if 'DEVICE' in globals():
            self.device = DEVICE

        self.V = nn.Parameter(torch.rand(max_length, d_model).to(self.device))

    def gen_pe(self, n):
        pe = np.zeros((self.max_length, self.d_model), dtype=np.float32)
        for k in np.arange(self.max_length):
            for i in np.arange(self.d_model):
                theta = k / (n ** (i/self.d_model))
                if i % 2 == 0:
                    pe[k,i] = math.sin(theta)
                else:
                    pe[k,i] = math.cos(theta)
        return pe
    
    def forward(self, x: torch.FloatTensor):
        # x is a lookup tensor
        initial_embeddings = self.relu(self.embedding(x))
        positional_embeddings = torch.tensor(self.gen_pe(1000)).to(self.device)
        
        input_embeddings = torch.add(initial_embeddings, positional_embeddings)

        _ = torch.matmul(input_embeddings, torch.transpose(input_embeddings, 0 , 1))
        _ = torch.divide(_, math.sqrt(self.d_model))
        _ = self.softmax(_)

        return self.relu(torch.matmul(_,self.V))

In [9]:
class ClassificationHead(nn.Module):
    def __init__(self, embedding_length, head_length):
        super().__init__()
        self.embedding_length = embedding_length
        self.head_length = head_length

        self.output_layer = nn.Linear(self.embedding_length, self.head_length)
    
    def forward(self, x):
        y = self.output_layer(x)

        # no more sigmoid; we will use BCEWithLogitsLoss which uses a starting sigmoid layer and is more numerically stable
        return y

In [10]:
class Classifier(nn.Module):
    def __init__(self, vocabulary_size, max_length, d_model, head_length):
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.max_length = max_length
        self.d_model = d_model
        self.head_length = head_length
        self.expected_unravel_length = max_length * d_model
        
        self.body = TransformerModule(self.vocabulary_size, self.max_length, self.d_model)
        self.head = ClassificationHead(self.expected_unravel_length, self.head_length)

    def forward(self, x):
        x = self.body(x)
        x = x.view(-1)
        x = self.head(x)
        # softmax removed; inherently applied by CE, and not applied by BCE
        return x

In [11]:
class MyCustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.Tensor(x)
        self.y = torch.Tensor(y)
        self.n_samples = len(x)
    
    def __getitem__(self, index):
        return self.x[index], self.y[index]
    
    def __len__(self):
        return self.n_samples

custom_dataset = MyCustomDataset(train_x.values, train_y.values)

loader = DataLoader(
    custom_dataset,
    batch_size = 5,
    shuffle = False
)

In [12]:
def train_fn(loader, model, optimizer, loss_fn, device="cpu"):
    loop = tqdm(loader)

    average_loss = 0
    count = 0

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=device).long()
        targets = targets.to(device=device)
        # print(data.shape)
        # print(targets.shape)

        for inst_idx in range(data.shape[0]):
            # Forward
            predictions = model.forward(data[inst_idx])
            loss = loss_fn(predictions, targets[inst_idx])
            # Backward
            optimizer.zero_grad()

            loss.backward()

            optimizer.step()

            # Update tqdm
            loop.set_postfix(loss=loss.item())

            average_loss += loss.item()
            count += 1

    average_loss = average_loss / count

    return average_loss

##### Note that starting `Epoch 2` up until `Epoch 100` for every n s.t. n % 2 = 0, ave_loss(n + 1) <  ave_loss(n), thus showing that average loss is going down per epoch.

In [13]:
NUM_EPOCHS = 100
EMBEDDING_LENGTH = 100
model = Classifier(vocabulary_size, train_x.shape[1], EMBEDDING_LENGTH, train_y.shape[1]).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()

for i in range(NUM_EPOCHS):
    ave_loss = train_fn(loader, model, optimizer, criterion, device=DEVICE)
    print(f'Epoch {i+1}: {ave_loss}')

100%|██████████| 8/8 [00:04<00:00,  1.62it/s, loss=0]   


Epoch 1: 43.29619160294533


100%|██████████| 8/8 [00:06<00:00,  1.17it/s, loss=0]      


Epoch 2: 223.03058290480658


100%|██████████| 8/8 [00:04<00:00,  1.65it/s, loss=0]   


Epoch 3: 161.9427566051483


100%|██████████| 8/8 [00:05<00:00,  1.42it/s, loss=0]   


Epoch 4: 138.77342677116394


100%|██████████| 8/8 [00:04<00:00,  1.75it/s, loss=0]   


Epoch 5: 121.69978597164155


100%|██████████| 8/8 [00:05<00:00,  1.40it/s, loss=0]   


Epoch 6: 107.1920020222664


100%|██████████| 8/8 [00:07<00:00,  1.10it/s, loss=0]   


Epoch 7: 94.2774814248085


100%|██████████| 8/8 [00:04<00:00,  1.70it/s, loss=0]      


Epoch 8: 82.44003247916699


100%|██████████| 8/8 [00:06<00:00,  1.18it/s, loss=0]      


Epoch 9: 71.41234919869513


100%|██████████| 8/8 [00:08<00:00,  1.10s/it, loss=0]      


Epoch 10: 56.11330948454558


100%|██████████| 8/8 [00:09<00:00,  1.16s/it, loss=0]      


Epoch 11: 43.10703712268672


100%|██████████| 8/8 [00:06<00:00,  1.33it/s, loss=0]      


Epoch 12: 37.04046649020493


100%|██████████| 8/8 [00:08<00:00,  1.06s/it, loss=0]      


Epoch 13: 31.41863469779488


100%|██████████| 8/8 [00:04<00:00,  1.84it/s, loss=0]      


Epoch 14: 25.447931685764342


100%|██████████| 8/8 [00:03<00:00,  2.12it/s, loss=0]      


Epoch 15: 14.458094190595673


100%|██████████| 8/8 [00:03<00:00,  2.03it/s, loss=0]       


Epoch 16: 6.100048968330036


100%|██████████| 8/8 [00:03<00:00,  2.01it/s, loss=0]       


Epoch 17: 1.4219501994264996


100%|██████████| 8/8 [00:04<00:00,  1.98it/s, loss=0]      


Epoch 18: 0.9200806865583246


100%|██████████| 8/8 [00:04<00:00,  1.93it/s, loss=0]       


Epoch 19: 0.9073669112531254


100%|██████████| 8/8 [00:03<00:00,  2.09it/s, loss=0]       


Epoch 20: 0.4330768512851776


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=1.79e-7] 


Epoch 21: 0.5492529778320424


100%|██████████| 8/8 [00:04<00:00,  1.94it/s, loss=0]       


Epoch 22: 0.4112312293113291


100%|██████████| 8/8 [00:03<00:00,  2.07it/s, loss=0]       


Epoch 23: 0.444729820684305


100%|██████████| 8/8 [00:03<00:00,  2.14it/s, loss=1.19e-7] 


Epoch 24: 0.43724975983395475


100%|██████████| 8/8 [00:03<00:00,  2.13it/s, loss=4.41e-6] 


Epoch 25: 0.6781935772851511


100%|██████████| 8/8 [00:03<00:00,  2.07it/s, loss=5.96e-8] 


Epoch 26: 1.0192404816778484


100%|██████████| 8/8 [00:03<00:00,  2.10it/s, loss=1.43e-6] 


Epoch 27: 0.4100311701341862


100%|██████████| 8/8 [00:03<00:00,  2.12it/s, loss=1.34e-5] 


Epoch 28: 0.5965887465123684


100%|██████████| 8/8 [00:03<00:00,  2.12it/s, loss=2.15e-6] 


Epoch 29: 0.5160801251570312


100%|██████████| 8/8 [00:03<00:00,  2.09it/s, loss=3.7e-6]  


Epoch 30: 0.7748133753576582


100%|██████████| 8/8 [00:03<00:00,  2.16it/s, loss=0.0147]  


Epoch 31: 0.6676001381813563


100%|██████████| 8/8 [00:03<00:00,  2.10it/s, loss=0.0821]  


Epoch 32: 0.7222008542491536


100%|██████████| 8/8 [00:03<00:00,  2.13it/s, loss=0.022]   


Epoch 33: 0.23199189735245157


100%|██████████| 8/8 [00:03<00:00,  2.04it/s, loss=0]       


Epoch 34: 1.5333949699591897e-05


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 35: 6.766274890903645e-05


100%|██████████| 8/8 [00:03<00:00,  2.11it/s, loss=0]       


Epoch 36: 6.490343075054028e-05


100%|██████████| 8/8 [00:04<00:00,  1.95it/s, loss=0]      


Epoch 37: 6.015028483741247e-05


100%|██████████| 8/8 [00:04<00:00,  1.99it/s, loss=0]       


Epoch 38: 5.597911512662534e-05


100%|██████████| 8/8 [00:03<00:00,  2.03it/s, loss=0]       


Epoch 39: 5.232332914086868e-05


100%|██████████| 8/8 [00:03<00:00,  2.09it/s, loss=0]       


Epoch 40: 4.9083496555191175e-05


100%|██████████| 8/8 [00:03<00:00,  2.09it/s, loss=0]       


Epoch 41: 4.619434475392481e-05


100%|██████████| 8/8 [00:03<00:00,  2.11it/s, loss=0]       


Epoch 42: 4.359946178231766e-05


100%|██████████| 8/8 [00:03<00:00,  2.08it/s, loss=0]       


Epoch 43: 4.1254298407178425e-05


100%|██████████| 8/8 [00:03<00:00,  2.04it/s, loss=0]      


Epoch 44: 3.912173196658486e-05


100%|██████████| 8/8 [00:04<00:00,  1.98it/s, loss=0]       


Epoch 45: 3.717502695144148e-05


100%|██████████| 8/8 [00:04<00:00,  1.94it/s, loss=0]       


Epoch 46: 3.539041499172058e-05


100%|██████████| 8/8 [00:04<00:00,  1.94it/s, loss=0]       


Epoch 47: 3.374709412602428e-05


100%|██████████| 8/8 [00:04<00:00,  1.94it/s, loss=0]       


Epoch 48: 3.222871018371265e-05


100%|██████████| 8/8 [00:03<00:00,  2.09it/s, loss=0]       


Epoch 49: 3.0821897958244195e-05


100%|██████████| 8/8 [00:03<00:00,  2.04it/s, loss=0]       


Epoch 50: 2.951178102668095e-05


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 51: 2.8289442703677282e-05


100%|██████████| 8/8 [00:03<00:00,  2.09it/s, loss=0]       


Epoch 52: 2.7147453507492257e-05


100%|██████████| 8/8 [00:03<00:00,  2.13it/s, loss=0]       


Epoch 53: 2.6079866071082593e-05


100%|██████████| 8/8 [00:03<00:00,  2.04it/s, loss=0]       


Epoch 54: 2.5076268499102382e-05


100%|██████████| 8/8 [00:04<00:00,  1.96it/s, loss=0]       


Epoch 55: 2.4130709058223233e-05


100%|██████████| 8/8 [00:04<00:00,  1.98it/s, loss=0]       


Epoch 56: 2.3240216974951265e-05


100%|██████████| 8/8 [00:03<00:00,  2.10it/s, loss=0]       


Epoch 57: 2.240033135967323e-05


100%|██████████| 8/8 [00:03<00:00,  2.14it/s, loss=0]       


Epoch 58: 2.1606586228894997e-05


100%|██████████| 8/8 [00:03<00:00,  2.14it/s, loss=0]       


Epoch 59: 2.0851540462896168e-05


100%|██████████| 8/8 [00:04<00:00,  1.94it/s, loss=0]       


Epoch 60: 2.0135197697523922e-05


100%|██████████| 8/8 [00:04<00:00,  1.99it/s, loss=0]       


Epoch 61: 1.9453092678300975e-05


100%|██████████| 8/8 [00:03<00:00,  2.08it/s, loss=0]       


Epoch 62: 1.8808206365061863e-05


100%|██████████| 8/8 [00:04<00:00,  1.94it/s, loss=0]       


Epoch 63: 1.819309436177363e-05


100%|██████████| 8/8 [00:04<00:00,  1.90it/s, loss=0]       


Epoch 64: 1.7606270917980284e-05


100%|██████████| 8/8 [00:03<00:00,  2.12it/s, loss=0]       


Epoch 65: 1.704475580126541e-05


100%|██████████| 8/8 [00:03<00:00,  2.07it/s, loss=0]       


Epoch 66: 1.6511531790541766e-05


100%|██████████| 8/8 [00:04<00:00,  1.96it/s, loss=0]       


Epoch 67: 1.5999151579393355e-05


100%|██████████| 8/8 [00:04<00:00,  1.99it/s, loss=0]       


Epoch 68: 1.5510595399348404e-05


100%|██████████| 8/8 [00:03<00:00,  2.00it/s, loss=0]       


Epoch 69: 1.5041395813852887e-05


100%|██████████| 8/8 [00:04<00:00,  1.93it/s, loss=0]       


Epoch 70: 1.4591553550502568e-05


100%|██████████| 8/8 [00:03<00:00,  2.04it/s, loss=0]       


Epoch 71: 1.4159579946948498e-05


100%|██████████| 8/8 [00:04<00:00,  2.00it/s, loss=0]       


Epoch 72: 1.3746966575212127e-05


100%|██████████| 8/8 [00:04<00:00,  1.99it/s, loss=0]       


Epoch 73: 1.334775442600744e-05


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 74: 1.2964923367775327e-05


100%|██████████| 8/8 [00:03<00:00,  2.03it/s, loss=0]       


Epoch 75: 1.2596984011992163e-05


100%|██████████| 8/8 [00:03<00:00,  2.03it/s, loss=0]       


Epoch 76: 1.2243939267619907e-05


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 77: 1.1904296110998302e-05


100%|██████████| 8/8 [00:03<00:00,  2.04it/s, loss=0]       


Epoch 78: 1.157507831024418e-05


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 79: 1.1257774162487521e-05


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 80: 1.0952385486717731e-05


100%|██████████| 8/8 [00:04<00:00,  1.98it/s, loss=0]       


Epoch 81: 1.0657421803017542e-05


100%|██████████| 8/8 [00:03<00:00,  2.15it/s, loss=0]       


Epoch 82: 1.0371394814256974e-05


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 83: 1.0097284205556889e-05


100%|██████████| 8/8 [00:03<00:00,  2.13it/s, loss=0]       


Epoch 84: 9.832109201823868e-06


100%|██████████| 8/8 [00:03<00:00,  2.13it/s, loss=0]       


Epoch 85: 9.571400726926526e-06


100%|██████████| 8/8 [00:03<00:00,  2.13it/s, loss=0]       


Epoch 86: 9.325589592990013e-06


100%|██████████| 8/8 [00:03<00:00,  2.17it/s, loss=0]       


Epoch 87: 9.082755780909224e-06


100%|██████████| 8/8 [00:03<00:00,  2.08it/s, loss=0]       


Epoch 88: 8.850348234190619e-06


100%|██████████| 8/8 [00:03<00:00,  2.12it/s, loss=0]       


Epoch 89: 8.62538981145633e-06


100%|██████████| 8/8 [00:03<00:00,  2.10it/s, loss=0]       


Epoch 90: 8.407877968963362e-06


100%|██████████| 8/8 [00:03<00:00,  2.05it/s, loss=0]       


Epoch 91: 8.19632477195853e-06


100%|██████████| 8/8 [00:03<00:00,  2.06it/s, loss=0]       


Epoch 92: 7.990729492846072e-06


100%|██████████| 8/8 [00:03<00:00,  2.12it/s, loss=0]       


Epoch 93: 7.792581883947492e-06


100%|██████████| 8/8 [00:03<00:00,  2.12it/s, loss=0]       


Epoch 94: 7.598903168215543e-06


100%|██████████| 8/8 [00:03<00:00,  2.16it/s, loss=0]      


Epoch 95: 7.412672848872148e-06


100%|██████████| 8/8 [00:03<00:00,  2.01it/s, loss=0]       


Epoch 96: 7.23091124221753e-06


100%|██████████| 8/8 [00:03<00:00,  2.02it/s, loss=0]      


Epoch 97: 7.055107917253167e-06


100%|██████████| 8/8 [00:04<00:00,  1.95it/s, loss=0]       


Epoch 98: 6.8837734854554355e-06


100%|██████████| 8/8 [00:03<00:00,  2.04it/s, loss=0]      


Epoch 99: 6.718397517246899e-06


100%|██████████| 8/8 [00:03<00:00,  2.03it/s, loss=0]       

Epoch 100: 6.5574911669585845e-06





In [14]:
### TODO
# change NUM_EPOCHS to 100 and run it! 
# note if you do not have a CUDA gpu or a CUDA installation of torch it might take long
# and you would have to change cell 2 to DEVICE = "cpu"
# the model needs to be trained first so we can evaluate it for number 4!

# Refactoring Aly's code to use sklearn

In [15]:
# # Given the trained model, evaluate the testing data using a confusion matrix 

# test_tensor = torch.Tensor(test.values).to(DEVICE).long()

# # Set model to evaluation mode
# model.eval()

# # Make predictions on testing data
# with torch.no_grad():
#     predictions = model(test_tensor) # this line raises an error

# # Convert predictions to numpy array
# predictions_np = predictions.cpu().numpy()

# # Convert predicted probabilities to predicted labels
# predicted_labels = np.argmax(predictions_np, axis=1)

# # Convert true labels to numpy array
# true_labels = test_y.iloc[:, 0].values  # Assuming the first column contains true labels

# # Compute confusion matrix
# conf_matrix = confusion_matrix(true_labels, predicted_labels)

# # Print confusion matrix
# print("Confusion Matrix:")
# print(conf_matrix)



RuntimeError: The size of tensor a (1524) must match the size of tensor b (1522) at non-singleton dimension 1