<a href="https://colab.research.google.com/github/ed21b006/my-public-repo/blob/main/EfficientNet_AdityaRaj_ED21B006.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
fromfrom google.colab import drive
drive.mount('/content/drive')

SyntaxError: ignored

In [None]:
from IPython.display import clear_output

!pip3 install pyprind

clear_output()

In [None]:
# Downloading and Preparing the Dataset

!gdown --id 1oYnD7Izl3LVVzjEMyLxLklX30TKWHgGG
!unzip /content/cifar-10.zip
!rm -rf /content/cifar-10.zip
!mv /content/cifar-10/sample_submission.csv /content/cifar-10/test_labels.csv

clear_output()

In [None]:
# Imports

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet

from PIL import Image

import pandas
import numpy
from sklearn import preprocessing
import matplotlib

import os
import pyprind

PATH = "/content/drive/MyDrive/"

In [2]:
class CreateDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, mode='train'):
        self.root_dir = root_dir
        self.mode = mode

        self.entry = pandas.read_csv(os.path.join(self.root_dir, f'{self.mode}_labels.csv'))
        self.encoder = self._process_()
        self.entry['label'] = self.encoder.transform(self.entry['label'])

        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((32,32)),  # LeNet requires input to be of 32x32 pixels 
                torchvision.transforms.ToTensor()
            ]
        )

    def _process_(self):
        data = pandas.read_csv(os.path.join(self.root_dir, 'train_labels.csv'))
        encoder = preprocessing.LabelEncoder()
        encoder.fit(data['label'])
        return encoder

    def __getitem__(self, index):
        data = self.entry.iloc[index]
        image = Image.open(f"/content/cifar-10/train/{data['id']}.png") 
        image = self.transform(image)
        label = data['label']
        return image, label

    def __len__(self):
        return len(self.entry)

NameError: ignored

In [None]:
class Network(torch.nn.Module):
    def __init__(self, model_name='efficientnet-b0', num_classes=10):
        super(Network, self).__init__()
        self.model = EfficientNet.from_pretrained(model_name)
        self.model._fc = nn.Linear(self.model._fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

In [None]:
class Trainer():
    def __init__(self, data):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.trainloader, self.validloader, self.testloader = self.get_iterator(data)
        
        self.model = self.get_model().to(self.device)
        self.criterion = self.get_criterion().to(self.device)
        self.optimizer = self.get_optimizer()

        self.train_loss = []
        self.train_metrics = []
        self.valid_loss = []
        self.valid_metrics = []

        self.epochs = 10

    def get_iterator(self, data):
        train, valid, test = data
        trainloader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, drop_last=True) 
        validloader = torch.utils.data.DataLoader(valid, batch_size=64, shuffle=False, drop_last=True) 
        testloader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=False) 
        return trainloader, validloader, testloader

    def get_criterion(self):
        return torch.nn.CrossEntropyLoss() 
    
    def get_optimizer(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) 

    def get_model(self):
        model = Network() 
        return model

    def save(self, epoch):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            }, os.path.join(PATH, "model.pth"))
        
    def load(self):
        if os.path.exists(os.path.join(PATH, "model.pth")):
            checkpoints = torch.load(os.path.join(self.args.checkpoint, "model.pth"), map_location=self.device)
            self.model.load_state_dict(checkpoints['model_state_dict'])
            self.optimizer.load_state_dict(checkpoints['optimizer_state_dict'])

    def train(self):
        epoch_loss = 0
        epoch_metrics = {}

        self.model.train()

        with torch.autograd.set_detect_anomaly(True):
            bar = pyprind.ProgBar(len(self.trainloader), bar_char='█')
            for index, (image, label) in enumerate(self.trainloader):  # for batches. 1 loop 1 batch. [total/64] iterations
                image = image.to(self.device)
                label = label.to(self.device)

                self.optimizer.zero_grad()
                
                output = self.model(image) 

                loss = self.criterion(output,label) 

                loss.backward() 
                epoch_loss += loss

                self.optimizer.step()
                bar.update()
        epoch_loss /= len(self.trainloader)

        return epoch_loss, epoch_metrics

    def evaluate(self):
        epoch_loss = 0
        epoch_metrics = {}

        with torch.autograd.set_detect_anomaly(True):
            bar = pyprind.ProgBar(len(self.validloader), bar_char='█')
            for index, (image, label) in enumerate(self.validloader):
                image = image.to(self.device)
                label = label.to(self.device)
                
                output = self.model(image) 

                loss = self.criterion(output,label) 
                epoch_loss += loss

                bar.update()
        epoch_loss /= len(self.validloader)

        return epoch_loss, epoch_metrics

    def test(self):

        self.model.eval()

        outputs = torch.empty([0,])

        with torch.autograd.set_detect_anomaly(True):
            bar = pyprind.ProgBar(len(self.testloader), bar_char='█')
            for index, (image, label) in enumerate(self.testloader):
                image = image.to(self.device)
                label = label.to(self.device)
                
                output = self.model(image) 
                outputs = torch.cat((outputs, output), dim=0)

                bar.update()

        return outputs
    
    def fit(self):
        # epochs=10
        for epoch in range(1, self.epochs+1, 1):

            epoch_train_loss, epoch_train_metrics = self.train()

            self.train_loss.append(epoch_train_loss)
            self.train_metrics.append(epoch_train_metrics)

            epoch_valid_loss, epoch_valid_metrics = self.evaluate()
            
            self.valid_loss.append(epoch_valid_loss)
            self.valid_metrics.append(epoch_valid_metrics) 

            print(f'Epoch {epoch}/{self.epochs+1}: Train Loss = {epoch_train_loss} | Validation Loss = {epoch_valid_loss}')

            # if epoch_valid_metrics >= max(self.valid_metrics):
            if epoch_valid_loss<=min(self.valid_loss):
                self.save(epoch)

In [None]:
train_data = CreateDataset(root_dir="/content/cifar-10/", mode="train")
train_data, valid_data = torch.utils.data.random_split(train_data, [len(train_data)-len(train_data)//10, len(train_data)//10])
test_data = CreateDataset(root_dir="/content/cifar-10/", mode="test")
data = (train_data, valid_data, test_data)

trainer = Trainer(data)
trainer.fit()

outputs = trainer.test()