# EECS759P Coursework 2 (CNN Classification Task)
- Name: Bheki Maenetja
- Student ID: 230382466

## Imports

In [None]:
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn

# !pip install plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
# pio.renderers.default = "iframe"

## Plotting Functions

In [None]:
# Plotting functions
def plot_data(x=None, y=None, z=None, size=None, colour=None, title="", colour_title="", x_label="", y_label="", name="", mode="markers", text="", fill=None, **traces):
    """
    General purpose function for plotting scatter plots in plotly.

    Builds a figure with the given title and axis titles. When both ``x`` and
    ``y`` are given, a main scatter trace is added: 3D if ``z`` is given,
    otherwise 2D. The trace is built by ``create_trace`` so that the marker and
    trace options are defined in exactly one place.

    Args:
        x, y, z: coordinates of the main trace; passing ``z`` switches to a 3D scatter.
        size: optional marker sizes (plotted with a fixed ``sizeref`` of 0.01).
        colour: optional marker colours; also shows a colour bar titled ``colour_title``.
        title, x_label, y_label: figure title and x/y axis titles.
        name, mode, text: passed through to the main trace.
        fill: fill mode for a 2D trace; not applied to 3D traces.
        **traces: extra ready-made traces, added after the main trace in keyword
            order (the keyword names themselves are not used).

    Returns:
        The constructed ``go.Figure``.
    """
    fig = go.Figure(layout={
        "title": title,
        "xaxis": {"title": x_label},
        "yaxis": {"title": y_label},
    })

    # Only draw the main trace when there is data on both axes.
    if x is not None and y is not None:
        fig.add_trace(create_trace(
            x=x,
            y=y,
            z=z,
            size=size,
            colour=colour,
            colour_title=colour_title,
            name=name,
            mode=mode,
            text=text,
            fill=fill,
        ))

    for trace in traces.values():
        fig.add_trace(trace)

    return fig

def create_trace(x=None, y=None, z=None, size=None, colour=None, colour_title="", name="", mode="lines", text="", fill=None):
    """
    Build a single plotly scatter trace, without wrapping it in a figure.

    Args:
        x, y, z: coordinates; a non-None ``z`` produces a ``go.Scatter3d``,
            otherwise a 2D ``go.Scatter`` is returned.
        size: optional marker sizes (plotted with a fixed ``sizeref`` of 0.01).
        colour: optional marker colours; also shows a colour bar titled ``colour_title``.
        name, mode, text: passed straight through to the trace.
        fill: fill mode, only applied to the 2D trace.

    Returns:
        The ``go.Scatter`` or ``go.Scatter3d`` trace.
    """
    marker_style = {}
    if size is not None:
        marker_style.update(size=size, sizeref=0.01)
    if colour is not None:
        marker_style.update(color=colour, showscale=True, colorbar=dict(title=colour_title))

    # Options shared by the 2D and 3D trace types.
    shared = dict(x=x, y=y, mode=mode, name=name, text=text, marker=marker_style)

    if z is not None:
        return go.Scatter3d(z=z, **shared)
    return go.Scatter(fill=fill, **shared)

def plot_collection(plots, rows=1, cols=1, title="", subplot_titles=None, x_labels=None, height=1000):
    """
    Combine several existing figures into one grid of 2D subplots.

    Args:
        plots: dict mapping a (row, col) grid position (1-based, as plotly's
            ``make_subplots`` expects) to a figure. Every trace of that figure
            is copied into the matching cell.
        rows, cols: grid dimensions.
        title: overall figure title.
        subplot_titles: optional sequence of per-cell titles for ``make_subplots``.
        x_labels: optional dict mapping a (row, col) key to that cell's x-axis
            title; cells not listed get an empty title.
        height: figure height.

    Returns:
        The combined plotly figure.
    """
    # None sentinels avoid sharing one mutable default list/dict across calls.
    if subplot_titles is None:
        subplot_titles = []
    if x_labels is None:
        x_labels = {}

    # Every cell is a plain 2D ("xy") subplot; build a distinct dict per cell.
    specs = [[{"type": "xy"} for _ in range(cols)] for _ in range(rows)]

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        specs=specs,
    )

    fig.update_layout({
        "title": title,
        "height": height,
    })

    for key, source_fig in plots.items():
        row, col = key[0], key[1]
        for trace in source_fig.data:
            fig.add_trace(trace, row=row, col=col)
        fig.update_xaxes(title_text=x_labels.get(key, ""), row=row, col=col)

    return fig

## Loading Data

In [None]:
# Fix the seed to be able to get the same randomness across runs and hence reproducible outcomes.
# Seeding first makes it obvious that everything random later in the notebook
# (DataLoader shuffling, model weight initialisation) is covered.
torch.manual_seed(0)

# Both splits use the same preprocessing: convert each image to a tensor.
to_tensor = transforms.ToTensor()
train_set = torchvision.datasets.FashionMNIST(root=".", train=True, download=True, transform=to_tensor)
test_set = torchvision.datasets.FashionMNIST(root=".", train=False, download=True, transform=to_tensor)

batch_size = 32
# Shuffle only the training data; the test order is irrelevant to accuracy.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

## CNN Setup

### FashionCNN Class

In [None]:
def initialise_weights(m):
    """
    Re-initialise the weight of a Linear or Conv2d module with Xavier (Glorot) normal init.

    Meant to be passed to ``nn.Module.apply``. Biases and every other module
    type are left untouched.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    nn.init.xavier_normal_(m.weight)

class FashionCNN(nn.Module):
    """
    Convolutional classifier producing 10 class scores per image.

    Two conv(5x5) -> ReLU -> max-pool(2x2) stages are followed by a fully
    connected head of 1024 -> 1024 -> 256 -> 10. The first Linear layer expects
    1024 flattened features (64 channels x 4 x 4), which matches 1x28x28 inputs
    such as FashionMNIST (28 -> 24 -> 12 -> 8 -> 4 spatially). Other input
    sizes would need a different first Linear width.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor; layers are created in the same order as listed, so
        # default-initialisation random draws happen in a fixed sequence.
        features = []
        for in_channels, out_channels in ((1, 32), (32, 64)):
            features += [
                nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]

        # Classifier head: flatten, two hidden ReLU layers, then 10 output scores.
        head = [nn.Flatten()]
        for fan_in, fan_out in ((1024, 1024), (1024, 256)):
            head += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        head.append(nn.Linear(256, 10))

        self.network = nn.Sequential(*features, *head)

        # Replace the default init of every Conv2d/Linear weight.
        self.network.apply(initialise_weights)

    def forward(self, x):
        """Return the raw (unnormalised) class scores for the batch ``x``."""
        logits = self.network(x)
        return logits

### Evaluation

In [None]:
def evaluation(model, dataloader, device=None):
    """
    Compute the classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: module returning per-class scores, with classes along dim 1.
        dataloader: any iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        device: optional device. When given, each batch is moved there before
            the forward pass; the model itself is assumed to already be on that
            device (it is not moved here).

    Returns:
        Accuracy as a percentage (100 * correct / total).

    Raises:
        ValueError: if ``dataloader`` yields no samples.

    Note:
        The model is switched to eval mode and left there. Callers that resume
        training must call ``model.train()`` again.
    """
    total, correct = 0, 0
    # Eval mode deactivates training-only behaviour such as dropout.
    model.eval()
    # Nothing here is backpropagated, so don't let autograd record the forward
    # pass. This saves the memory and time of building an unused graph.
    with torch.no_grad():
        for inputs, labels in dataloader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # The predicted class is the index of the highest score.
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            # .item() converts the 0-d count tensor into a Python number.
            correct += (pred == labels).sum().item()
    if total == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return 100 * correct / total

### Training Function

In [None]:
def train_model(model, train_loader, test_loader, alpha=0.1, max_epochs=30):
    """
    Train ``model`` with SGD and cross-entropy loss, tracking accuracy per epoch.

    After every epoch the model is scored on both loaders with ``evaluation``,
    and a progress line is printed.

    Args:
        model: the module to train; its parameters are updated in place.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        test_loader: iterable of ``(inputs, labels)`` held-out batches.
        alpha: SGD learning rate.
        max_epochs: number of full passes over ``train_loader``.

    Returns:
        Tuple ``(loss_per_epoch, train_acc, test_acc)`` of per-epoch lists:
        ``loss_per_epoch[e]`` is the SUM of the batch losses in epoch ``e``,
        and the accuracies are percentages.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(), lr=alpha)

    loss_per_epoch = []
    train_acc = []
    test_acc = []

    for e in range(max_epochs):
        # evaluation() leaves the model in eval mode at the end of each epoch,
        # so switch back to training mode before the next pass.
        model.train()
        epoch_sum = 0
        n_batches = 0

        for inputs, labels in train_loader:
            opt.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()
            epoch_sum += loss.item()
            n_batches += 1

        loss_per_epoch.append(epoch_sum)
        train_acc.append(evaluation(model, train_loader))
        test_acc.append(evaluation(model, test_loader))
        # Print a true per-batch average to match the "Avg Loss" label.
        # The returned list keeps the epoch sums, as before.
        avg_loss = epoch_sum / max(n_batches, 1)
        print(f"Epoch {e+1} | Avg Loss: {avg_loss} | Train accuracy: {train_acc[-1]} | Test accuracy: {test_acc[-1]}")

    return loss_per_epoch, train_acc, test_acc

In [None]:
# Instantiate the network to train. Construction draws from torch's global RNG
# (default layer init plus the Xavier re-init in FashionCNN.__init__), so this
# cell is only reproducible if the seeding cell has run first.
my_cnn = FashionCNN()

In [None]:
# Keep the per-epoch history (summed loss, train/test accuracy) so it can be
# plotted later instead of only being echoed as cell output.
loss_per_epoch, train_acc, test_acc = train_model(my_cnn, train_loader, test_loader)