In [1]:
import torch
from torch import optim
import torchvision
from torchvision import models
from torchvision.transforms import v2
import torch.nn as nn
from PIL import Image
import pandas as pd
print(torchvision.__version__)
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

0.18.0+cu121


In [2]:
classes = ('anger', 'contempt', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise')

In [3]:
def preprocess(image):
    transform = v2.Compose([
        v2.ToImage() ,
        v2.ToDtype(torch.uint8, scale=True),
        v2.CenterCrop((96, 96)),
        v2.ToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image)
    return image


In [4]:
def create_one_hot_encoding(label, classes):
    one_hot = torch.zeros(len(classes), dtype=torch.float32)
    
    if label in classes:
        index = classes.index(label)
        one_hot[index] = 1.0
    return one_hot

In [5]:
      
class CustomFERDataset(Dataset):
    def __init__(self, image_parent_directory, data_directory, transform=None):
        self.image_parent_directory = image_parent_directory
        self.df = pd.read_csv(data_directory)
        self.transform = transform
        self.classes = classes

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.image_parent_directory + self.df.iloc[idx, 0]
        image = Image.open(img_name).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        label = self.df.iloc[idx, 1]
        label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
        
        return image, label_tensor

In [6]:
train_data = CustomFERDataset("data/archive/" , "data/train.csv" , transform=preprocess)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)

test_data = CustomFERDataset("data/archive/" , "data/test.csv" , transform=preprocess)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [7]:
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

In [13]:
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data, target
            output = model(data)
            
            test_loss += criterion(output, target).item()  # sum up batch loss
            
            # Convert one-hot encoded target to class labels if needed
            if target.ndim > 1:
                _, target = target.max(1)  # convert one-hot encoded target to class labels
                
            # Calculate correct predictions
            _, pred = output.max(1)  # get the index of the max log-probability
            correct += pred.eq(target).sum().item()
            
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')


In [9]:
class SimpleResNet(nn.Module):
    def __init__(self, num_classes=8):
        super(SimpleResNet, self).__init__()
        self.model = models.resnet18(pretrained=False)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)  # Modify the last layer for CIFAR-10

    def forward(self, x):
        return self.model(x)

In [10]:
model = SimpleResNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [11]:
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader, criterion)

# Save the model
torch.save(model.state_dict(), "simple_resnet.pth")

  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<11:36,  1.01it/s]



 14%|█▍        | 101/705 [01:05<07:11,  1.40it/s]



 29%|██▊       | 201/705 [01:56<03:57,  2.12it/s]



 43%|████▎     | 301/705 [02:46<03:29,  1.93it/s]



 57%|█████▋    | 401/705 [03:36<02:44,  1.85it/s]



 71%|███████   | 501/705 [04:31<02:05,  1.62it/s]



 85%|████████▌ | 601/705 [05:22<00:59,  1.75it/s]



 99%|█████████▉| 701/705 [06:11<00:01,  2.10it/s]



100%|██████████| 705/705 [06:13<00:00,  1.89it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<09:39,  1.22it/s]



 14%|█▍        | 101/705 [00:50<04:51,  2.07it/s]



 29%|██▊       | 201/705 [01:39<03:48,  2.20it/s]



 43%|████▎     | 301/705 [02:26<03:05,  2.18it/s]



 57%|█████▋    | 401/705 [03:12<02:17,  2.21it/s]



 71%|███████   | 501/705 [03:57<01:30,  2.24it/s]



 85%|████████▌ | 601/705 [04:43<00:47,  2.17it/s]



 99%|█████████▉| 701/705 [05:28<00:01,  2.14it/s]



100%|██████████| 705/705 [05:30<00:00,  2.13it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<07:40,  1.53it/s]



 14%|█▍        | 101/705 [00:47<04:34,  2.20it/s]



 29%|██▊       | 201/705 [01:33<04:05,  2.06it/s]



 43%|████▎     | 301/705 [02:20<03:19,  2.02it/s]



 57%|█████▋    | 401/705 [03:07<02:34,  1.96it/s]



 71%|███████   | 501/705 [03:53<01:37,  2.10it/s]



 85%|████████▌ | 601/705 [04:39<00:47,  2.18it/s]



 99%|█████████▉| 701/705 [05:26<00:01,  2.19it/s]



100%|██████████| 705/705 [05:28<00:00,  2.15it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<07:45,  1.51it/s]



 14%|█▍        | 101/705 [00:46<04:39,  2.16it/s]



 29%|██▊       | 201/705 [01:33<03:50,  2.19it/s]



 43%|████▎     | 301/705 [02:19<03:01,  2.23it/s]



 57%|█████▋    | 401/705 [03:07<02:21,  2.15it/s]



 71%|███████   | 501/705 [03:53<01:32,  2.21it/s]



 85%|████████▌ | 601/705 [04:39<00:46,  2.26it/s]



 99%|█████████▉| 701/705 [05:25<00:01,  2.21it/s]



100%|██████████| 705/705 [05:27<00:00,  2.15it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<07:55,  1.48it/s]



 14%|█▍        | 101/705 [00:46<04:30,  2.23it/s]



 29%|██▊       | 201/705 [01:32<03:49,  2.20it/s]



 43%|████▎     | 301/705 [02:18<03:01,  2.22it/s]



 57%|█████▋    | 401/705 [03:04<02:18,  2.19it/s]



 71%|███████   | 501/705 [03:49<01:32,  2.21it/s]



 85%|████████▌ | 601/705 [04:35<00:47,  2.18it/s]



 99%|█████████▉| 701/705 [05:21<00:01,  2.15it/s]



100%|██████████| 705/705 [05:23<00:00,  2.18it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<08:13,  1.43it/s]



 14%|█▍        | 101/705 [00:47<04:38,  2.17it/s]



 29%|██▊       | 201/705 [01:34<03:51,  2.17it/s]



 43%|████▎     | 301/705 [02:21<03:05,  2.17it/s]



 57%|█████▋    | 401/705 [03:08<02:26,  2.07it/s]



 71%|███████   | 501/705 [03:55<01:35,  2.15it/s]



 85%|████████▌ | 601/705 [04:41<00:46,  2.25it/s]



 99%|█████████▉| 701/705 [05:29<00:01,  2.12it/s]



100%|██████████| 705/705 [05:30<00:00,  2.13it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<07:52,  1.49it/s]



 14%|█▍        | 101/705 [00:46<04:42,  2.14it/s]



 29%|██▊       | 201/705 [01:32<04:01,  2.09it/s]



 43%|████▎     | 301/705 [02:18<03:02,  2.21it/s]



 57%|█████▋    | 401/705 [6:15:35<01:46,  2.86it/s]      



 71%|███████   | 501/705 [6:16:14<01:29,  2.27it/s]



 85%|████████▌ | 601/705 [6:16:52<00:37,  2.77it/s]



 99%|█████████▉| 701/705 [6:17:31<00:01,  2.67it/s]



100%|██████████| 705/705 [6:17:33<00:00, 32.13s/it]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<06:53,  1.70it/s]



 14%|█▍        | 101/705 [00:42<04:51,  2.08it/s]



 29%|██▊       | 201/705 [01:20<03:16,  2.56it/s]



 43%|████▎     | 301/705 [02:00<02:44,  2.45it/s]



 57%|█████▋    | 401/705 [02:39<01:55,  2.64it/s]



 71%|███████   | 501/705 [03:20<01:31,  2.22it/s]



 85%|████████▌ | 601/705 [04:01<00:43,  2.40it/s]



 99%|█████████▉| 701/705 [04:44<00:01,  2.54it/s]



100%|██████████| 705/705 [04:45<00:00,  2.47it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<07:29,  1.56it/s]



 14%|█▍        | 101/705 [00:43<04:14,  2.37it/s]



 29%|██▊       | 201/705 [01:23<03:13,  2.61it/s]



 43%|████▎     | 301/705 [02:02<02:38,  2.55it/s]



 57%|█████▋    | 401/705 [02:41<01:58,  2.57it/s]



 71%|███████   | 501/705 [03:21<01:20,  2.53it/s]



 85%|████████▌ | 601/705 [04:01<00:41,  2.51it/s]



 99%|█████████▉| 701/705 [04:41<00:01,  2.47it/s]



100%|██████████| 705/705 [04:43<00:00,  2.49it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  0%|          | 1/705 [00:00<06:47,  1.73it/s]



 14%|█▍        | 101/705 [00:40<04:03,  2.48it/s]



 29%|██▊       | 201/705 [01:22<03:18,  2.54it/s]



 43%|████▎     | 301/705 [02:03<02:42,  2.48it/s]



 57%|█████▋    | 401/705 [02:43<02:06,  2.41it/s]



 71%|███████   | 501/705 [03:24<01:25,  2.37it/s]



 85%|████████▌ | 601/705 [04:05<00:42,  2.45it/s]



 99%|█████████▉| 701/705 [04:46<00:01,  2.51it/s]



100%|██████████| 705/705 [04:48<00:00,  2.45it/s]
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))


In [14]:
test_model = SimpleResNet()
test_model.load_state_dict(torch.load("simple_resnet.pth"))
criterion = nn.CrossEntropyLoss()
test(test_model, test_loader, criterion)


  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
  label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))



Test set: Average loss: 0.0513, Accuracy: 1744/2818 (61.89%)

