In [1]:
import torch
import torch.nn as nn
from torchvision.ops import MLP
import pandas as pd
import numpy as np
from tqdm import tqdm

In [2]:
data_train = pd.read_csv("./dataset/fer2013.csv")
data_train.head()

Unnamed: 0,emotion,pixels,Usage
0,0,70 80 82 72 58 58 60 63 54 58 60 48 89 115 121...,Training
1,0,151 150 147 155 148 133 111 140 170 174 182 15...,Training
2,2,231 212 156 164 174 138 161 173 182 200 106 38...,Training
3,4,24 32 36 30 32 23 19 20 30 41 21 22 32 34 21 1...,Training
4,6,4 0 0 0 0 0 0 0 0 0 0 0 3 15 23 28 48 50 58 84...,Training


In [3]:
# make the datset to an array of images of pixels 
image_array =[]
for i, row in enumerate(data_train.index):
        image = np.fromstring(data_train.loc[row, 'pixels'], dtype=int, sep=' ')
        image_array.append(image.flatten())

print(f"{len(image_array)} images in the dataset")
print(f"Each image has {len(image_array[0])} pixels")

lables = np.array(data_train['emotion']).tolist()
lables

print(f"Number of lables: {len(lables)}")
print(f"Number of unique lables: {len(set(lables))}")

35887 images in the dataset
Each image has 2304 pixels
Number of lables: 35887
Number of unique lables: 7


In [4]:
flat_images = np.array(image_array)
target = np.array(lables)

# normalization
flat_images = flat_images / 255

In [5]:
class SimpleMLP(nn.Module):
    def __init__(self, in_channels, hidden_channels, n_classes):
        super(SimpleMLP, self).__init__()
        self.mlp = MLP(in_channels=in_channels,
                       hidden_channels=hidden_channels,
                       norm_layer=None,  # Not using normalization here
                       activation_layer=nn.ReLU,
                       bias=True,
                       dropout=0.5)  # Example dropout for regularization
        self.classifier = nn.Linear(hidden_channels[-1], n_classes)

    def forward(self, x):
        x = self.mlp(x)
        x = self.classifier(x)
        return x

In [6]:
shapes = [
    [64],
    [128],
    [256],
    [512],
    [1024],
    [128, 64],
    [256, 128],
    [512, 256],
    [1024, 512],
    [256, 128, 64],
    [512, 256, 128],
    [1024, 512, 256],
    [512, 256, 128, 64],
    [1024, 512, 256, 128],
    [1024, 512, 256, 128, 64]
]
models = [SimpleMLP(in_channels=flat_images.shape[1], hidden_channels=shape, n_classes=7) for shape in shapes]
optimizers = [torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5) for model in models]
criterion = nn.CrossEntropyLoss()
batch_size = 1000

nexp = flat_images.shape[0]

# seperate the first 80% of the dataset for training and the rest for testing
ntrain = int(nexp * 0.8)
ntest = nexp - ntrain
train_images = torch.tensor(flat_images[:ntrain], dtype=torch.float32)
train_target = torch.tensor(target[:ntrain], dtype=torch.int64)
test_images = torch.tensor(flat_images[ntrain:], dtype=torch.float32)
test_target = torch.tensor(target[ntrain:], dtype=torch.int64)

def run_epoch(model, optimizer):
    for i in range(0, train_images.shape[0], batch_size):
        images = train_images[i:i+batch_size]
        labels = train_target[i:i+batch_size]
    
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

def accuracy(model):
    images = test_images
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)
    total = images.size(0)
    correct = (predicted == test_target).sum().item()
    return correct / total

count = 0
while count < 500:
    count += 1
    for i in range(len(models)):
        model = models[i]
        optimizer = optimizers[i]
        run_epoch(model, optimizer)
        
    output = f"Epoch: {count}"
    for i, model in enumerate(models):
        output += f" - {shapes[i]}: {accuracy(model):.4f}"
    print(output)

for i, model in enumerate(models):
    model_path = f"model"
    shape = shapes[i]
    for s in shape:
        model_path += f"_{s}"
    model_path += ".pth"
    torch.save(model.state_dict(), model_path)
        



Epoch: 1 - [64]: 0.2221 - [128]: 0.2360 - [256]: 0.2381 - [512]: 0.2347 - [1024]: 0.2460 - [128, 64]: 0.2179 - [256, 128]: 0.2138 - [512, 256]: 0.2419 - [1024, 512]: 0.2439 - [256, 128, 64]: 0.1506 - [512, 256, 128]: 0.2172 - [1024, 512, 256]: 0.2385 - [512, 256, 128, 64]: 0.1913 - [1024, 512, 256, 128]: 0.1442 - [1024, 512, 256, 128, 64]: 0.1740
Epoch: 2 - [64]: 0.2189 - [128]: 0.2377 - [256]: 0.2396 - [512]: 0.2409 - [1024]: 0.2476 - [128, 64]: 0.2233 - [256, 128]: 0.2243 - [512, 256]: 0.2361 - [1024, 512]: 0.2425 - [256, 128, 64]: 0.1744 - [512, 256, 128]: 0.2324 - [1024, 512, 256]: 0.2386 - [512, 256, 128, 64]: 0.2137 - [1024, 512, 256, 128]: 0.2062 - [1024, 512, 256, 128, 64]: 0.1761
Epoch: 3 - [64]: 0.2293 - [128]: 0.2402 - [256]: 0.2442 - [512]: 0.2460 - [1024]: 0.2549 - [128, 64]: 0.2278 - [256, 128]: 0.2272 - [512, 256]: 0.2373 - [1024, 512]: 0.2434 - [256, 128, 64]: 0.1856 - [512, 256, 128]: 0.2393 - [1024, 512, 256]: 0.2359 - [512, 256, 128, 64]: 0.2244 - [1024, 512, 256, 12

Epoch: 500 - [64]: 0.3639 - [128]: 0.3713 - [256]: 0.3731 - [512]: 0.3725 - [1024]: 0.3812 - [128, 64]: 0.3713 - [256, 128]: 0.3860 - [512, 256]: 0.4062 - [1024, 512]: 0.4281 - [256, 128, 64]: 0.3831 - [512, 256, 128]: 0.4007 - [1024, 512, 256]: 0.4205 - [512, 256, 128, 64]: 0.3856 - [1024, 512, 256, 128]: 0.4113 - [1024, 512, 256, 128, 64]: 0.3982