# In [1]:
import random
import os
import numpy as np
import torch
from torchvision.transforms.functional import to_tensor, normalize, affine
from PIL import Image
from typing import Tuple, List, NamedTuple
from tqdm import tqdm
import zipfile
from urllib import request
import torch.nn as nn

# Seed NumPy, PyTorch and the standard-library `random` module with the same
# fixed value so each generator starts from the same state on every run.
np.random.seed(197331)
torch.manual_seed(197331)
random.seed(197331)


class NetworkConfiguration(NamedTuple):
    """Architecture hyper-parameters consumed by the network builders.

    The MLP builder reads only ``dense_hiddens``; the CNN builder reads every
    field and indexes ``kernel_sizes`` and ``strides`` by convolution layer.
    """
    # Output channels of each convolution layer, in order.
    n_channels: Tuple[int, ...] = (16, 32, 48)
    # Kernel size of each convolution layer (same indexing as n_channels).
    kernel_sizes: Tuple[int, ...] = (3, 3, 3)
    # Stride of each convolution layer (same indexing as n_channels).
    strides: Tuple[int, ...] = (1, 1, 1)
    # Width of each hidden fully connected layer, in order.
    dense_hiddens: Tuple[int, ...] = (256, 256)


class Trainer:
    """Build, train and evaluate a single-output regression network.

    On construction the trainer loads the ``rotated_fashion_mnist`` dataset,
    builds either an MLP or a CNN ending in a one-unit linear layer, and
    attaches an Adam optimizer. Training minimises the mean squared error;
    the mean absolute error is tracked alongside it.
    """

    def __init__(self,
                 network_type: str = "mlp",
                 net_config: NetworkConfiguration = NetworkConfiguration(),
                 lr: float = 0.001,
                 batch_size: int = 128,
                 activation_name: str = "relu"):
        """
        Load the data and build the network and optimizer.

        :param network_type: Either "mlp" or "cnn".
        :param net_config: Architecture description passed to the builder.
        :param lr: Learning rate given to Adam.
        :param batch_size: Number of samples per training step.
        :param activation_name: One of "relu", "tanh" or "sigmoid".
        :raises ValueError: If the network type or activation is not supported.
        """
        # Validate the cheap arguments first so that a typo fails before the
        # dataset is downloaded and decoded.
        if network_type not in ("mlp", "cnn"):
            raise ValueError("Network type not supported")
        activation_function = self.create_activation_function(activation_name)

        self.train, self.test = self.load_dataset()
        self.network_type = network_type
        # Shape of a single image, i.e. everything after the sample axis.
        input_dim = self.train[0].shape[1:]
        if network_type == "mlp":
            # The MLP flattens each image, so it needs the total element count.
            self.network = self.create_mlp(input_dim.numel(),
                                           net_config,
                                           activation_function)
        else:
            # The CNN only needs the number of input channels.
            self.network = self.create_cnn(input_dim[0],
                                           net_config,
                                           activation_function)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
        self.lr = lr
        self.batch_size = batch_size

        # One entry is appended to every list each time log_metrics runs.
        self.train_logs = {'train_loss': [], 'test_loss': [],
                           'train_mae': [], 'test_mae': []}

    @staticmethod
    def load_dataset() -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """
        Load the train and test splits, downloading the archive if needed.

        Images are converted to tensors and normalised with mean 0.5 and
        std 0.5. Labels are parsed from the file names and standardised with
        the mean and standard deviation of the training labels.

        :return: ``((X_train, y_train), (X_test, y_test))``; every ``y`` has
                 shape ``(n_samples, 1)``.
        """
        import shutil  # Only needed here, for streaming the download to disk.

        datapath = './rotated_fashion_mnist'
        if not os.path.exists(datapath):
            url = 'https://drive.google.com/u/0/uc?id=1NQPmr01eIafQKeH9C9HR0lGuB5z6mhGb&export=download&confirm=t&uuid=645ff20a-d47b-49f0-ac8b-4a7347529c8e&at=AHV7M3d_Da0D7wowJlTzzZxDky5c:1669325231545'
            archive_path = './rotated_fashion_mnist.zip'
            try:
                # Stream the response to disk instead of holding it in memory.
                with request.urlopen(url) as response, open(archive_path, 'wb') as out:
                    shutil.copyfileobj(response, out)
                with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                    zip_ref.extractall()
            finally:
                # Never leave a (possibly partial) archive behind.
                if os.path.exists(archive_path):
                    os.remove(archive_path)

        def get_paths_and_rots(split: str) -> Tuple[List[str], List[float]]:
            """Return the image paths of ``split`` and the label parsed from each name."""
            image_paths, rots = [], []
            for file in os.listdir(os.path.join(datapath, split)):
                image_paths.append(os.path.join(datapath, split, file))
                # The label is the text between the first '_' and the next '.'.
                # NOTE(review): assumes names look like '<id>_<rotation>.<ext>'.
                rots.append(float(file.split('_')[1].split('.')[0]))
            return image_paths, rots

        def load_image(path: str) -> torch.Tensor:
            """Read one image into a normalised tensor and close the file."""
            with Image.open(path) as image:
                return normalize(to_tensor(image), (0.5,), (0.5,))

        def to_tensors(image_paths: List[str], rots: List[float]) -> Tuple[torch.Tensor, torch.Tensor]:
            """Stack the images into one tensor and the labels into a column."""
            images = torch.stack([load_image(path) for path in image_paths])
            labels = torch.tensor(rots).view(-1, 1)
            return images, labels

        X_train, y_train = to_tensors(*get_paths_and_rots('train'))
        X_test, y_test = to_tensors(*get_paths_and_rots('test'))

        # Normalize y for easier training. The test labels use the training
        # statistics so both splits share the same scale.
        mean, std = y_train.mean(), y_train.std()
        y_train = (y_train - mean) / std
        y_test = (y_test - mean) / std

        return (X_train, y_train), (X_test, y_test)

    @staticmethod
    def _dense_head(in_features: int, hidden_sizes: Tuple[int, ...],
                    activation: torch.nn.Module) -> List[torch.nn.Module]:
        """Return the hidden Linear+activation layers followed by a 1-unit Linear."""
        layers = []
        for hidden in hidden_sizes:
            layers.append(nn.Linear(in_features, hidden))
            layers.append(activation)
            in_features = hidden
        # With no hidden layers the output layer reads the input directly.
        layers.append(nn.Linear(in_features, 1))
        return layers

    @staticmethod
    def create_mlp(input_dim: int, net_config: NetworkConfiguration,
                   activation: torch.nn.Module) -> torch.nn.Module:
        """
        Create a multi-layer perceptron (MLP) network.

        :param input_dim: Number of features per sample after flattening.
        :param net_config: a NetworkConfiguration named tuple. Only the field 'dense_hiddens' will be used.
        :param activation: The activation function to use.
        :return: A PyTorch model implementing the MLP.
        """
        return nn.Sequential(
            nn.Flatten(),
            *Trainer._dense_head(input_dim, net_config.dense_hiddens, activation))

    @staticmethod
    def create_cnn(in_channels: int, net_config: NetworkConfiguration,
                   activation: torch.nn.Module) -> torch.nn.Module:
        """
        Create a convolutional network.

        Every convolution is followed by the activation; a 2x2 max-pool sits
        between consecutive convolutions. The feature map is then adaptively
        max-pooled to 4x4, flattened and fed to the dense layers.

        :param in_channels: The number of channels in the input image.
        :param net_config: a NetworkConfiguration specifying the architecture of the CNN.
        :param activation: The activation function to use.
        :return: A PyTorch model implementing the CNN.
        """
        layers = []
        channels = in_channels
        for i, out_channels in enumerate(net_config.n_channels):
            if i > 0:
                # Downsample between convolutions, never before the first one.
                layers.append(nn.MaxPool2d(kernel_size=2))
            layers.append(nn.Conv2d(channels, out_channels,
                                    net_config.kernel_sizes[i],
                                    stride=net_config.strides[i]))
            layers.append(activation)
            channels = out_channels
        # Fixed 4x4 output makes the dense input size independent of image size.
        layers.append(nn.AdaptiveMaxPool2d((4, 4)))
        layers.append(nn.Flatten())
        layers.extend(Trainer._dense_head(4 * 4 * channels,
                                          net_config.dense_hiddens,
                                          activation))
        return nn.Sequential(*layers)

    @staticmethod
    def create_activation_function(activation_str: str) -> torch.nn.Module:
        """
        Instantiate the activation module named by ``activation_str``.

        :param activation_str: One of "relu", "tanh" or "sigmoid".
        :return: A new activation module.
        :raises ValueError: If the name is not supported.
        """
        activations = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}
        if activation_str not in activations:
            # Fail here rather than handing None to nn.Sequential later on.
            raise ValueError(
                f"Activation function not supported: {activation_str!r}")
        return activations[activation_str]()

    def compute_loss_and_mae(self, X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the network on ``X`` and compare its predictions with ``y``.

        :param X: Batch of inputs accepted by the network.
        :param y: Targets with the same shape as the network output.
        :return: ``(mse, mae)``, both averaged over the batch.
        """
        predictions = self.network(X)
        mse = nn.MSELoss()(predictions, y)
        mae = nn.L1Loss()(predictions, y)
        return mse, mae

    def training_step(self, X_batch: torch.Tensor, y_batch: torch.Tensor):
        """
        Perform one optimizer update on a mini-batch using the MSE loss.

        :param X_batch: Mini-batch of inputs.
        :param y_batch: Matching targets.
        """
        # log_metrics/evaluate switch to eval mode; switch back before updating.
        self.network.train()
        self.optimizer.zero_grad()
        loss, _ = self.compute_loss_and_mae(X_batch, y_batch)
        loss.backward()
        self.optimizer.step()

    def log_metrics(self, X_train: torch.Tensor, y_train: torch.Tensor,
                    X_test: torch.Tensor, y_test: torch.Tensor) -> None:
        """
        Append the current train/test loss and MAE to ``self.train_logs``.

        :param X_train: Training inputs to score.
        :param y_train: Training targets.
        :param X_test: Test inputs to score.
        :param y_test: Test targets.
        """
        self.network.eval()
        with torch.inference_mode():
            train_loss, train_mae = self.compute_loss_and_mae(X_train, y_train)
            test_loss, test_mae = self.compute_loss_and_mae(X_test, y_test)
        self.train_logs['train_mae'].append(train_mae.item())
        self.train_logs['test_mae'].append(test_mae.item())
        self.train_logs['train_loss'].append(train_loss.item())
        self.train_logs['test_loss'].append(test_loss.item())

    def train_loop(self, n_epochs: int) -> dict:
        """
        Train for ``n_epochs`` passes over the training set.

        Metrics are logged once before training and once after every epoch,
        using the first 2000 training samples and the full test set.

        :param n_epochs: Number of passes over the training data.
        :return: The ``train_logs`` dictionary.
        """
        # Prepare train and validation data
        X_train, y_train = self.train
        X_test, y_test = self.test

        # The last batch may be smaller than batch_size.
        n_batches = int(np.ceil(X_train.shape[0] / self.batch_size))

        self.log_metrics(X_train[:2000], y_train[:2000], X_test, y_test)
        for epoch in tqdm(range(n_epochs)):
            for batch in range(n_batches):
                start = self.batch_size * batch
                stop = self.batch_size * (batch + 1)
                self.training_step(X_train[start:stop, :], y_train[start:stop, :])
            self.log_metrics(X_train[:2000], y_train[:2000], X_test, y_test)
        return self.train_logs

    def evaluate(self, X: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the loss and MAE on ``(X, y)`` without tracking gradients.

        :param X: Inputs to score.
        :param y: Matching targets.
        :return: ``(mse, mae)`` tensors that are detached from the graph.
        """
        self.network.eval()
        # no_grad avoids building a graph for data that is never backpropagated
        # and guarantees that both returned tensors are detached.
        with torch.no_grad():
            loss, mae = self.compute_loss_and_mae(X, y)
        return loss, mae

    def test_equivariance(self):
        """
        Scaffold for an equivariance experiment; the comparison itself is
        not implemented yet, so this method currently returns ``None``.

        It prepares a test image, a small two-layer convolutional model with
        freshly initialised weights, and shift/rotation transforms.
        """
        from functools import partial
        # Undo the (0.5, 0.5) normalisation of the first training image.
        test_im = (self.train[0][0] + 1) / 2
        conv = torch.nn.Conv2d(kernel_size=3, in_channels=1, out_channels=1, stride=1, padding=0)
        # The same convolution is applied twice, each time followed by ReLU.
        fullconv_model = lambda x: torch.relu(conv((torch.relu(conv((x))))))
        model = fullconv_model

        shift_amount = 5
        shift = partial(affine, angle=0, translate=(shift_amount, shift_amount), scale=1, shear=0)
        rotation = partial(affine, angle=90, translate=(0, 0), scale=1, shear=0)

        # TODO CODE HERE
        pass

def main():
    """Train an MLP with ReLU activations for 50 epochs and print its logs."""
    architecture = dict(
        n_channels=(16, 32, 48),
        kernel_sizes=(3, 3, 3),
        strides=(1, 1, 1),
        dense_hiddens=(128, 128),
    )
    trainer = Trainer(
        network_type="mlp",
        net_config=NetworkConfiguration(**architecture),
        lr=0.01,
        batch_size=128,
        activation_name="relu",
    )
    # train_loop returns the per-epoch metric history.
    print(trainer.train_loop(50))


# Run the training script only when executed directly, not when imported.
if __name__=="__main__":
    main()

# Residual notebook output, kept as comments so the file parses as Python:
#   from .autonotebook import tqdm as notebook_tqdm
#
# AssertionError: