In [419]:
import torch
from torch import optim
import torchvision
from torchvision.transforms import v2
import torch.nn as nn
from PIL import Image
import pandas as pd
print(torchvision.__version__)
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

0.18.0+cu121


In [420]:
classes = ('anger', 'contempt', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise')

In [421]:
def preprocess(image):
    transform = v2.Compose([
        v2.ToImage() ,
        v2.ToDtype(torch.uint8, scale=True),
        v2.CenterCrop((96, 96)),
        v2.ToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image)
    return image

In [422]:
img = Image.open("data/archive/surprise/ffhq_238.png").convert('RGB')
img_tensor = preprocess(img)
img_tensor.shape




torch.Size([3, 96, 96])

In [423]:
def create_one_hot_encoding(label, classes):
    one_hot = torch.zeros(len(classes), dtype=torch.float32)
    
    if label in classes:
        index = classes.index(label)
        one_hot[index] = 1.0
    return one_hot

In [424]:
      
class CustomFERDataset(Dataset):
    def __init__(self, image_parent_directory, data_directory, transform=None):
        self.image_parent_directory = image_parent_directory
        self.df = pd.read_csv(data_directory)
        self.transform = transform
        self.classes = classes

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.image_parent_directory + self.df.iloc[idx, 0]
        image = Image.open(img_name).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        label = self.df.iloc[idx, 1]
        label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
        
        return image, label_tensor

In [425]:
train_data = CustomFERDataset("data/archive/" , "data/train.csv" , transform=preprocess)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)


In [426]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.conv7 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        self.bn8 = nn.BatchNorm2d(128)
        self.bn9 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(128 * 12 * 12, 512)
        self.fc2 = nn.Linear(512, 8)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        
        residual = self.relu(self.bn7(self.conv7(x)))
        out = self.relu(self.bn8(self.conv8(residual)))
        out = self.bn9(self.conv9(out))
        out += residual
        x = self.pool(self.relu(out))
        
        x = x.view(-1, 128 * 12 * 12)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


In [427]:
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [428]:
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        # if batch_idx ==100:
        optimizer.zero_grad()
        output = model(data)
        # print(output.shape)
        # print(target)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

In [429]:
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            if test_loss == 0:  # Only print for the first batch
                print(f"Data shape: {data.shape}")
                print(f"Target shape: {target.shape}")
                print(f"Output shape: {output.shape}")
                print(f"Output (first 5): {output[:5]}")
                print(f"Target (first 5): {target[:5]}")
            test_loss += criterion(output, target).item()  # sum up batch loss
            _,pred = output.max(1)
            _, target_labels = target.max(1)# get the index of the max log-probability
            correct += pred.eq(target_labels).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)\n')

In [430]:
test_data = CustomFERDataset("data/archive/" , "data/test.csv" , transform=preprocess)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [431]:
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader, criterion)

torch.save(model.state_dict(), "simple_cnn.pth")

  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:01,  1.45s/it]



101it [02:18,  1.60s/it]



201it [04:50,  1.60s/it]



301it [07:12,  1.50s/it]



401it [09:34,  1.42s/it]



501it [11:42,  1.22s/it]



601it [13:48,  1.30s/it]



701it [15:57,  1.39s/it]



705it [16:03,  1.37s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-0.4710,  0.3178, -0.1207, -0.8079,  0.6070,  0.2783, -0.3981,  0.3680],
        [ 0.8582,  0.7894,  0.8609,  1.0823, -0.8360,  0.2878,  0.7046,  1.3613],
        [ 0.3311,  0.2078,  0.2971,  0.2899, -0.2189,  0.0081,  0.2298,  0.4742],
        [ 1.0884,  0.9574,  1.1036,  1.3661, -0.8987,  0.4264,  0.8964,  1.6825],
        [-0.4396,  0.1868, -0.1417, -0.7802,  0.5249,  0.1371, -0.4102,  0.1952]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0617, Accuracy: 671/2818 (24%)



  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:01,  1.51s/it]



101it [02:15,  1.31s/it]



201it [04:28,  1.26s/it]



301it [06:40,  1.34s/it]



401it [08:50,  1.30s/it]



501it [11:04,  1.27s/it]



601it [13:16,  1.34s/it]



701it [15:30,  1.34s/it]



705it [15:35,  1.33s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-0.2847,  0.5774,  0.1787, -0.2543,  0.2153,  0.2943, -0.2776,  0.3595],
        [ 1.1878, -0.0963,  0.9901,  1.8201, -2.9724,  0.8691,  1.2638,  1.8122],
        [ 1.1247, -0.0911,  0.9322,  1.6982, -2.7962,  0.8095,  1.1833,  1.7063],
        [ 1.3840, -0.1125,  1.1699,  2.1994, -3.5205,  1.0547,  1.5143,  2.1419],
        [-1.1900,  0.6676, -0.0775, -2.2372,  1.8150,  0.0861, -0.7011, -0.4644]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0549, Accuracy: 837/2818 (30%)



  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:02,  2.07s/it]



101it [02:51,  1.35s/it]



201it [05:04,  1.32s/it]



301it [07:21,  1.33s/it]



401it [09:37,  1.32s/it]



501it [11:53,  1.43s/it]



601it [14:12,  1.47s/it]



701it [16:30,  1.33s/it]



705it [16:35,  1.41s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-0.5622,  0.3284,  0.6800,  0.1990, -0.6368,  0.4544,  0.2836,  1.0925],
        [ 0.8327, -0.6975,  1.0119,  2.0095, -2.8461,  1.0574,  0.8862,  2.1898],
        [ 0.6112, -0.3610,  0.6201,  0.9626, -1.7080,  0.5515,  0.4855,  1.2062],
        [ 0.6162, -0.3687,  0.6290,  0.9864, -1.7339,  0.5630,  0.4946,  1.2286],
        [-1.9109,  0.9464,  0.5401, -2.1310,  1.9441, -0.7802, -0.2365, -0.5572]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0535, Accuracy: 1051/2818 (37%)



  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:01,  1.68s/it]



101it [02:16,  1.35s/it]



201it [04:27,  1.33s/it]



301it [06:41,  1.27s/it]



401it [08:55,  1.39s/it]



501it [11:12,  1.40s/it]



601it [13:31,  1.54s/it]



701it [15:47,  1.49s/it]



705it [15:53,  1.35s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-0.8661,  0.0917,  0.5829,  1.5090, -1.5955,  0.7063,  0.1646,  2.2138],
        [ 0.5167, -0.8198,  0.6995,  2.2001, -2.8956,  0.8553,  0.8179,  2.3967],
        [ 0.5161, -0.8956,  0.7384,  2.4365, -3.1081,  0.9324,  0.8853,  2.6203],
        [ 0.5183, -0.6011,  0.5874,  1.5180, -2.2827,  0.6329,  0.6235,  1.7516],
        [-2.8129,  0.8954,  0.4646, -1.2983,  2.4873, -1.2160, -1.4659, -0.5250]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0500, Accuracy: 1179/2818 (42%)



  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:35, 35.66s/it]



101it [02:47,  1.43s/it]



201it [05:05,  1.38s/it]



301it [07:19,  1.28s/it]



401it [09:34,  1.28s/it]



501it [11:46,  1.44s/it]



601it [14:01,  1.32s/it]



701it [16:19,  1.37s/it]



705it [16:25,  1.40s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-1.9388, -1.0630, -0.1584,  2.9277, -2.6648,  1.7166,  0.8760,  3.8602],
        [ 0.2539, -1.4753,  0.2201,  2.4211, -2.8178,  0.9750,  0.9131,  2.3756],
        [-0.0261, -2.4696,  0.1230,  4.4293, -4.2767,  1.6341,  1.4304,  4.1047],
        [ 0.2102, -1.6302,  0.2049,  2.7340, -3.0451,  1.0776,  0.9937,  2.6450],
        [-1.4604,  0.8863,  0.1670, -1.8307,  1.2861, -0.4253, -0.7906, -0.1874]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0483, Accuracy: 1218/2818 (43%)



  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:01,  1.61s/it]



101it [02:17,  1.54s/it]



201it [04:36,  1.42s/it]



301it [06:56,  1.29s/it]



401it [09:14,  1.38s/it]



501it [11:35,  1.32s/it]



601it [13:53,  1.38s/it]



701it [16:45,  1.25s/it]



705it [16:50,  1.43s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-1.5662, -1.8157, -0.1471,  3.1828, -2.9756,  1.2129,  0.7687,  4.0949],
        [-0.3933, -2.3529,  0.0822,  3.7342, -3.5603,  0.9975,  1.1140,  3.7992],
        [-0.4221, -2.3739,  0.0750,  3.7798, -3.5865,  1.0120,  1.1203,  3.8519],
        [-0.0369, -1.6369,  0.1927,  2.3843, -2.6963,  0.6953,  0.8397,  2.5420],
        [-2.1939,  0.9761,  0.4000, -1.9673,  1.7181, -1.3113, -0.9177, -0.5327]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0469, Accuracy: 1242/2818 (44%)



  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:01,  1.58s/it]



101it [02:21,  1.35s/it]



201it [04:41,  1.30s/it]



301it [07:04,  1.44s/it]



401it [09:26,  1.31s/it]



501it [11:48,  1.62s/it]



601it [14:11,  1.40s/it]



701it [16:27,  1.34s/it]



705it [16:33,  1.41s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-1.9290, -2.0310, -1.0342,  3.5581, -2.7979,  1.0756,  0.3150,  3.5670],
        [-0.2963, -2.2836, -0.3289,  3.3672, -3.0838,  0.6503,  0.7885,  2.8143],
        [-0.4879, -2.6608, -0.4729,  4.0551, -3.4632,  0.7565,  0.8736,  3.3430],
        [ 0.0621, -1.5781, -0.0596,  2.0806, -2.3741,  0.4516,  0.6292,  1.8255],
        [-3.7604,  1.0551,  0.1886, -2.6393,  2.3739, -1.4859, -1.3443, -0.4750]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0458, Accuracy: 1228/2818 (44%)



  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:01,  1.59s/it]



101it [02:18,  1.61s/it]



201it [04:39,  1.42s/it]



301it [06:59,  1.32s/it]



401it [09:20,  1.33s/it]



501it [11:41,  1.49s/it]



601it [14:35,  1.25s/it]



701it [16:55,  1.42s/it]



705it [17:01,  1.45s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-1.7517e+00, -2.4675e+00, -1.3323e-01,  2.7059e+00, -3.3534e+00,
          7.3123e-01,  3.2985e-01,  3.4536e+00],
        [-1.0766e+00, -2.7117e+00,  1.2273e-01,  3.0520e+00, -3.6559e+00,
          5.7917e-01,  6.0243e-01,  3.3489e+00],
        [-1.7075e+00, -3.5451e+00, -4.4137e-03,  4.2312e+00, -4.5246e+00,
          7.4957e-01,  6.8568e-01,  4.5134e+00],
        [ 3.5591e-01, -8.1923e-01,  4.1148e-01,  3.7404e-01, -1.6830e+00,
          1.9218e-01,  4.1339e-01,  7.0440e-01],
        [-2.2822e+00,  2.6645e-01,  6.9364e-01, -2.2471e+00,  1.3867e+00,
         -2.0865e+00, -7.3459e-01, -1.8427e-01]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.04

  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:01,  1.62s/it]



101it [02:20,  1.61s/it]



201it [04:37,  1.34s/it]



301it [06:58,  1.43s/it]



401it [09:25,  1.37s/it]



501it [11:50,  1.45s/it]



601it [14:15,  1.46s/it]



701it [16:40,  1.45s/it]



705it [16:46,  1.43s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-3.0827e+00, -3.3537e+00, -7.5164e-01,  2.8034e+00, -3.7316e+00,
          3.8140e-01,  4.4233e-01,  3.9057e+00],
        [-1.0853e+00, -3.2311e+00,  1.1498e-02,  3.0240e+00, -3.7688e+00,
          1.2428e-01,  7.2732e-01,  3.0199e+00],
        [-1.6292e+00, -4.0725e+00, -1.3649e-01,  4.0088e+00, -4.5070e+00,
          1.3432e-01,  8.2352e-01,  3.9149e+00],
        [-8.9410e-01, -2.9354e+00,  6.3521e-02,  2.6778e+00, -3.5093e+00,
          1.2076e-01,  6.9350e-01,  2.7053e+00],
        [-4.9838e+00,  1.3625e+00,  2.6074e-03, -3.4854e+00,  3.0293e+00,
         -3.4396e+00, -1.1868e+00, -5.6060e-01]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.04

  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
1it [00:02,  2.03s/it]



101it [02:27,  1.39s/it]



201it [04:53,  1.34s/it]



301it [07:18,  1.39s/it]



401it [09:38,  1.45s/it]



501it [12:32,  1.63s/it]



601it [14:51,  1.38s/it]



701it [17:13,  1.51s/it]



705it [17:19,  1.47s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


Data shape: torch.Size([32, 3, 96, 96])
Target shape: torch.Size([32, 8])
Output shape: torch.Size([32, 8])
Output (first 5): tensor([[-3.0439, -5.0921, -0.9990,  2.9183, -3.8426,  1.5390, -0.0827,  3.8359],
        [-1.7091, -5.6480, -0.3927,  4.1650, -4.5684,  1.1656,  0.4794,  3.8816],
        [-1.6520, -5.5303, -0.3717,  4.0633, -4.4947,  1.1420,  0.4782,  3.7922],
        [ 0.1515, -1.8115,  0.2933,  0.8518, -2.1661,  0.3968,  0.4395,  0.9688],
        [-6.3030,  1.9269, -0.1349, -4.9553,  4.2518, -4.4918, -3.0143, -0.9516]])
Target (first 5): tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]])

Test set: Average loss: 0.0436, Accuracy: 1360/2818 (48%)

