In [24]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
import csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import argparse

In [3]:
# Class names in label-index order: integer label i corresponds to LABEL_NAMES[i].
LABEL_NAMES = [
    'background',
    'kart',
    'pickup',
    'nitro',
    'bomb',
    'projectile',
]


class SuperTuxDataset(Dataset):
    """Image-classification dataset described by ``<dataset_path>/labels.csv``.

    The CSV has a header row, then one row per sample whose first column is an
    image filename relative to ``dataset_path`` and whose second column is a
    name from ``LABEL_NAMES``. Labels are stored as indices into ``LABEL_NAMES``.
    """

    def __init__(self, dataset_path):
        """
        @dataset_path: directory containing ``labels.csv`` and the image files.

        Raises ValueError (from ``list.index``) if a row's label is not in LABEL_NAMES.

        WARNING: Do not perform data normalization here.
        """
        self.images = []
        self.labels = []
        self.transform = transforms.ToTensor()
        # Attribute name (including its typo) kept so existing callers still find it.
        self.complte_path = dataset_path + "/labels.csv"
        # The csv module documents opening files with newline='' for csv.reader.
        with open(self.complte_path, newline='') as csv_file:
            reader = csv.reader(csv_file, delimiter=",")

            next(reader, None)  # skip the header row; tolerate an empty file

            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing newline)
                    continue
                self.images.append(dataset_path + "/" + row[0])
                self.labels.append(LABEL_NAMES.index(row[1]))

    def __len__(self):
        """Number of samples listed in labels.csv."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        return a tuple: img, label
        (img is the output of transforms.ToTensor(); label is an int index into LABEL_NAMES)
        """
        # Using the image as a context manager closes its file handle once the
        # transform has read the pixels; otherwise every access leaks a handle.
        with Image.open(self.images[idx]) as image:
            image_tensor = self.transform(image)
        return image_tensor, self.labels[idx]


def load_data(dataset_path, num_workers=0, batch_size=128):
    """Build a shuffling DataLoader over SuperTuxDataset(dataset_path).

    The final, possibly smaller, batch is kept (drop_last=False).
    """
    loader_options = dict(
        num_workers=num_workers,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
    )
    return DataLoader(SuperTuxDataset(dataset_path), **loader_options)


def accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring column equals the label.

    @outputs: torch.Tensor((B,C)) per-class scores
    @labels:  torch.Tensor((B,)) class indices
    @return:  0-dim float tensor in [0, 1]
    """
    predicted = outputs.max(dim=1).indices.type_as(labels)
    hits = predicted == labels
    return hits.float().mean()

In [4]:
train_loader = load_data(dataset_path="./data/train")

In [18]:
# Keep one training batch around for quick interactive checks of the models and loss.
# next(iter(...)) avoids leaking loop variables (`inputs`/`labels`) into the notebook
# namespace, and fails loudly (StopIteration) on an empty loader instead of silently
# leaving inputs_evaluation/labels_evaluation undefined.
inputs_evaluation, labels_evaluation = next(iter(train_loader))

## Loss Function (Cross-Entropy)

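For a sample with logits $x \in \mathbb{R}^C$ and true label $l$, the loss is the negative log of the softmax probability assigned to the true class:

$$-\log\left(\frac{\exp(x_l)}{\sum_{j=1}^{C} \exp(x_j)}\right)$$

The batch loss is the mean of this quantity over all samples in the batch.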

In [14]:
class ClassificationLoss(torch.nn.Module):
    """Batch-mean cross-entropy between raw logits and integer class targets."""

    def forward(self, input, target):
        """
        Compute mean(-log(softmax(input)_label))

        @input:  torch.Tensor((B,C))  raw, unnormalized scores
        @target: torch.Tensor((B,), dtype=torch.int64)  class indices
        @return: torch.Tensor((,))  scalar mean loss
        """
        # cross_entropy fuses log_softmax with the negative log-likelihood;
        # reduction='mean' (the default) averages over the batch.
        return F.cross_entropy(input, target, reduction='mean')

## Models

In [15]:
class LinearClassifier(torch.nn.Module):
    """Single linear layer mapping a flattened 3x64x64 image to 6 class logits."""

    def __init__(self):
        super().__init__()
        # The attribute name `linear` determines the state_dict keys that
        # save_model writes and load_model reads back.
        self.linear = nn.Linear(in_features=3 * 64 * 64, out_features=6)

    def forward(self, x):
        """
        @x: torch.Tensor((B,3,64,64))
        @return: torch.Tensor((B,6)) unnormalized logits
        """
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. permuted or channels_last tensors); reshape copies only if needed.
        x = x.reshape(-1, 3 * 64 * 64)
        return self.linear(x)


class MLPClassifier(torch.nn.Module):
    """MLP 3*64*64 -> 256 -> 128 -> 6 with ReLU between the linear layers."""

    def __init__(self):
        super().__init__()
        # The attribute name `sequential` determines the state_dict keys that
        # save_model writes and load_model reads back.
        self.sequential = nn.Sequential(
            nn.Linear(in_features=3 * 64 * 64, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=6)
        )

    def forward(self, x):
        """
        @x: torch.Tensor((B,3,64,64))
        @return: torch.Tensor((B,6)) unnormalized logits
        """
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. permuted or channels_last tensors); reshape copies only if needed.
        x = x.reshape(-1, 3 * 64 * 64)
        return self.sequential(x)

In [25]:
# Maps the model name used in checkpoint filenames ('<name>.th') to its class.
model_factory = dict(
    linear=LinearClassifier,
    mlp=MLPClassifier,
)


def save_model(model):
    """Save `model.state_dict()` to '<name>.th', where <name> is its model_factory key.

    The file goes next to this source file when `__file__` is defined, otherwise
    (e.g. inside a Jupyter kernel) into the current working directory.
    Raises ValueError if `model` is not an instance of a model_factory class.
    """
    from torch import save
    from os import path, getcwd
    try:
        base_dir = path.dirname(path.abspath(__file__))
    except NameError:
        # Jupyter kernels do not define __file__; referencing it would raise here.
        base_dir = getcwd()
    for n, m in model_factory.items():
        if isinstance(model, m):
            return save(model.state_dict(), path.join(base_dir, '%s.th' % n))
    raise ValueError("model type '%s' not supported!" % str(type(model)))


def load_model(model):
    """Instantiate model_factory[model] and load weights from '<model>.th' onto the CPU.

    The file is looked up next to this source file when `__file__` is defined,
    otherwise (e.g. inside a Jupyter kernel) in the current working directory.
    Raises KeyError for an unknown model name.
    """
    from torch import load
    from os import path, getcwd
    try:
        base_dir = path.dirname(path.abspath(__file__))
    except NameError:
        # Jupyter kernels do not define __file__; referencing it would raise here.
        base_dir = getcwd()
    r = model_factory[model]()
    # NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
    r.load_state_dict(load(path.join(base_dir, '%s.th' % model), map_location='cpu'))
    return r


In [16]:
# Standalone instances for interactive experiments; train() builds its own model and loss.
criterion = ClassificationLoss()
mlp_model = MLPClassifier()

### Train

In [29]:
def _train_one_epoch(model, loader, optimizer, criterion, writer, epoch, epochs, global_step, log_every):
    """Run one optimization pass over `loader` and return the updated global step.

    Every `log_every` batches, the running mean training loss of this epoch is
    printed and written to TensorBoard under 'Training/Loss'.
    """
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(loader, start=1):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # One step per optimizer update, counted across epochs, so TensorBoard
        # x-values are unique and increasing instead of colliding between epochs.
        global_step += 1
        if batch_idx % log_every == 0:
            mean_loss = running_loss / batch_idx
            writer.add_scalar('Training/Loss', mean_loss, global_step)
            print(f'[Training] Epoch: {epoch + 1}/{epochs} - Count: {batch_idx}: {mean_loss}')
    return global_step


def _evaluate(model, loader, criterion):
    """Return (per-sample mean loss, accuracy) of `model` over every batch of `loader`.

    Assumes `criterion` returns the batch-mean loss (as ClassificationLoss does), so
    each batch loss is re-weighted by batch size for an exact mean over samples.
    Returns (nan, nan) when the loader yields no samples.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct / seen


def train(model_type='linear', epochs=10, train_path="./data/train", valid_path="./data/valid",
          lr=0.01, log_every=50):
    """Train a model_factory[model_type] model with SGD, logging to TensorBoard.

    @model_type: key of model_factory
    @epochs:     number of passes over the training set
    @train_path: dataset directory passed to load_data for training
    @valid_path: dataset directory passed to load_data for validation
    @lr:         SGD learning rate
    @log_every:  print/log the running training loss every this many batches (> 0)
    @return:     the trained model
    Raises ValueError if log_every is not positive; KeyError for an unknown model_type.

    Validation loss and accuracy are computed over the full validation set once
    per epoch and logged at the current global training step.
    """
    if log_every <= 0:
        raise ValueError(f"log_every must be positive, got {log_every}")
    model = model_factory[model_type]()
    train_loader = load_data(train_path)
    val_loader = load_data(valid_path)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = ClassificationLoss()
    writer = SummaryWriter(log_dir=f'runs/{model_type}_train')

    global_step = 0
    try:
        for epoch in range(epochs):
            global_step = _train_one_epoch(model, train_loader, optimizer, criterion, writer,
                                           epoch, epochs, global_step, log_every)
            val_loss, val_acc = _evaluate(model, val_loader, criterion)
            writer.add_scalar('Validation/Loss', val_loss, global_step)
            writer.add_scalar('Validation/Accuracy', val_acc, global_step)
            print(f'[Validation] Epoch: {epoch + 1}/{epochs} - Loss: {val_loss} - Accuracy: {val_acc}')
    finally:
        # Close the event file even if training is interrupted (e.g. KeyboardInterrupt).
        writer.close()
    return model

In [30]:
train(model_type='mlp', epochs=10)

Epoch: 1/10 - Count: 50: 1.7043973517417907
Epoch: 1/10 - Count: 100: 1.596347371339798
Epoch: 1/10 - Count: 150: 1.4944430581728618
Epoch: 2/10 - Count: 50: 1.109833184480667
Epoch: 2/10 - Count: 100: 1.0563062697649002
Epoch: 2/10 - Count: 150: 1.0112765292326609
Epoch: 3/10 - Count: 50: 0.8452679479122162
Epoch: 3/10 - Count: 100: 0.839406788945198
Epoch: 3/10 - Count: 150: 0.8235836776097616
Epoch: 4/10 - Count: 50: 0.7628085958957672
Epoch: 4/10 - Count: 100: 0.7526772117614746
Epoch: 4/10 - Count: 150: 0.7318068087100983
Epoch: 5/10 - Count: 50: 0.6954132652282715
Epoch: 5/10 - Count: 100: 0.6752057605981827
Epoch: 5/10 - Count: 150: 0.6637463609377543
Epoch: 6/10 - Count: 50: 0.6298600697517395
Epoch: 6/10 - Count: 100: 0.625765155851841
Epoch: 6/10 - Count: 150: 0.6163016454378764
Epoch: 7/10 - Count: 50: 0.5857199186086655
Epoch: 7/10 - Count: 100: 0.5740009704232216
Epoch: 7/10 - Count: 150: 0.5733136488993963
Epoch: 8/10 - Count: 50: 0.5618437474966049
Epoch: 8/10 - Count: 1

## Data Visualization

In [31]:
def visualize_data(args):
    """Show a grid of sample images: `args.n` rows, one column per label in LABEL_NAMES.

    @args: object with attributes `dataset` (path passed to SuperTuxDataset) and
           `n` (number of examples to show per class), e.g. an argparse.Namespace.
    """
    # The imports cell does not import matplotlib; importing here keeps this
    # function working in a fresh kernel instead of raising NameError on `plt`.
    import matplotlib.pyplot as plt

    dataset = SuperTuxDataset(args.dataset)
    num_classes = len(LABEL_NAMES)

    # squeeze=False keeps `axes` 2-D even when args.n == 1; otherwise
    # plt.subplots returns a 1-D array and axes[c, label] would fail.
    fig, axes = plt.subplots(args.n, num_classes, squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    counts = [0] * num_classes
    for img, label in dataset:
        c = counts[label]
        if c < args.n:
            ax = axes[c, label]
            # CHW tensor -> HWC array for imshow
            ax.imshow(img.permute(1, 2, 0).numpy())
            ax.set_title(LABEL_NAMES[label])
            counts[label] += 1
        if sum(counts) >= args.n * num_classes:
            break

    plt.show()