# Classification model:


## Dataloader (TODO: the image directory and category-file paths are hard-coded and should be made configurable)

In [None]:
import glob
import os
import torch
from PIL import Image
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision
import numpy as np
from skimage.io import imread
from skimage.transform import resize
import sys
from tqdm import tqdm


# Make shared custom libraries that live one directory up (e.g. utils.py)
# importable from this notebook.
sys.path.append(os.path.join(os.curdir, os.pardir))

# NOTE(review): this pattern has no wildcard, so it matches only the directory
# itself (at most one entry), not the images inside it -- confirm intent.
image_paths = glob.glob('../../../extracted')
len(image_paths)

# Species name -> integer class label (the target returned by the dataset).
fishies = {
    species: label
    for label, species in enumerate(("p virens", "g morhua", "h lanceolatus", "background"))
}

class Fishy(torch.utils.data.Dataset):
    """
    Class-balanced fish image dataset.

    Picture names are grouped by integer label using a ``;``-separated
    category file. Every label gets the same number of index slots (the size
    of the largest group); smaller groups are cycled through repeatedly to
    fill their slots, so all labels are equally represented.

    Attributes:
        transform : Callable applied to each loaded PIL image.
        data_path : Directory containing the ``<picture>.jpg`` files.
        fish_dict : Maps integer label -> list of picture names.
        mostPicturesSameCat : Size of the largest label group.
        lengthOfArray : Dataset length (``mostPicturesSameCat * number of labels``).
        image_paths : ``*.jpg`` files found in ``data_path`` (not used for indexing).
    """

    def __init__(self, train, transform, data_path='extracted', category_path="all_fish.txt",
                 label_map=None):
        """
        Read the category file and group picture names by label.

        Parameters:
            train : Unused; kept for backward compatibility.
            transform : Callable applied to each PIL image in ``__getitem__``.
            data_path : Directory holding the ``.jpg`` images.
            category_path : Text file with one ``<picture>;<species>`` line per
                picture. One character is removed from each end of the second
                field (presumably a separator space and the newline, or
                surrounding quotes -- TODO confirm against the data file).
            label_map : Species name -> integer label. Defaults to the
                module-level ``fishies`` mapping.

        Raises:
            KeyError : If a species in the category file is not in ``label_map``.
        """
        if label_map is None:
            label_map = fishies
        self.transform = transform
        self.data_path = data_path

        self.fish_dict = {}
        # Context manager so the category file is always closed.
        with open(category_path, 'r') as fp:
            for line in fp:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                fields = line.split(";")
                picture = fields[0]
                species = fields[1][1:-1]
                self.fish_dict.setdefault(label_map[species], []).append(picture)

        # Sorted labels give every label a fixed slot position, so labels do
        # not have to be the contiguous range 0..N-1.
        self._labels = sorted(self.fish_dict)
        self.mostPicturesSameCat = max((len(v) for v in self.fish_dict.values()), default=0)
        self.lengthOfArray = self.mostPicturesSameCat * len(self._labels)
        self.image_paths = glob.glob(data_path + '/*.jpg')

    def __len__(self):
        """
        Returns the total number of samples.

        Returns:
            int : ``mostPicturesSameCat * number of labels``.
        """
        return self.lengthOfArray

    def _locate(self, idx):
        """Map a flat index to ``(label, picture_name)``, cycling small groups."""
        position, offset = divmod(idx, self.mostPicturesSameCat)
        label = self._labels[position]
        pictures = self.fish_dict[label]
        return label, pictures[offset % len(pictures)]

    def __getitem__(self, idx):
        """
        Generate one sample of data.

        Parameters:
            idx (int): Flat index in ``[0, len(self))``.

        Returns:
            tuple : ``(transform(image), label)`` where ``label`` is the integer
            label assigned to the picture's species.
        """
        label, picture = self._locate(idx)
        image_path = os.path.join(self.data_path, picture + ".jpg")
        # Release the file handle right away, and convert to RGB so grayscale
        # or CMYK JPEGs match the 3-channel Normalize used in createDataLoaders.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return self.transform(image), label

def createDataLoaders(batch_size, size, rotation=45, train_distribution=0.8,
                      num_workers=6, seed=None):
    """
    Build train/test DataLoaders over a random split of the Fishy dataset.

    Parameters:
        batch_size : Samples per batch for both loaders.
        size : Images are resized to ``size x size``.
        rotation : Maximum absolute angle (degrees) for RandomRotation.
        train_distribution : Fraction of samples assigned to the training split.
        num_workers : Worker processes used by each DataLoader.
        seed : If given, seeds the split so it is reproducible across runs;
            ``None`` uses torch's global RNG (the previous behaviour).

    Returns:
        tuple : ``(train_loader, test_loader, trainset, testset)``.
    """
    # NOTE(review): this augmenting transform (random rotation/flip) is applied
    # to the test split as well, and since Fishy repeats pictures to balance the
    # labels, the same picture can land in both splits. Both likely make test
    # accuracy noisy/optimistic -- consider splitting by picture first.
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.RandomRotation(rotation),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    full_dataset = Fishy(train=True, transform=transform)
    train_size = int(train_distribution * len(full_dataset))
    test_size = len(full_dataset) - train_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    trainset, testset = torch.utils.data.random_split(
        full_dataset, [train_size, test_size], **split_kwargs)

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    return train_loader, test_loader, trainset, testset

## Network definition

In [None]:
from torch.utils.data import DataLoader
# Reference values for Net's constructor arguments (the training cell below
# defines the values that are actually used):
#   height, width   = 512, 512
#   num_classes     = 4
#   channels        = 3
#   kernel_size     = 3
#   conv_stride     = 1
#   conv_pad        = 1
#   conv_drop_rate  = 0.4


class Net(nn.Module):
    """
    Small CNN classifier: three convolutions followed by a two-layer MLP head.

    ``forward`` returns raw class scores (logits) of shape
    ``(batch, num_classes)``, which is what ``nn.CrossEntropyLoss`` expects (it
    applies log-softmax internally). Pass ``apply_softmax=True`` to get
    per-class probabilities instead, as earlier versions of this class did.
    """

    def __init__(self, channels, kernel_size, conv_stride, conv_pad, conv_drop_rate,
                 num_classes, image_size, apply_softmax=False):
        """
        Parameters:
            channels : Number of input image channels.
            kernel_size, conv_stride, conv_pad : Settings shared by all three convolutions.
            conv_drop_rate : Dropout2d probability after the first two conv blocks.
            num_classes : Number of output classes.
            image_size : Height and width of the square input images.
            apply_softmax : If True, append a softmax so ``forward`` returns
                probabilities. Leave False when training with CrossEntropyLoss.
        """
        super().__init__()
        # Layer order is kept stable so saved state_dicts stay loadable.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=16, kernel_size=kernel_size,
                      stride=conv_stride, padding=conv_pad),
            nn.BatchNorm2d(num_features=16),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Dropout2d(p=conv_drop_rate),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=kernel_size,
                      stride=conv_stride, padding=conv_pad),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Dropout2d(p=conv_drop_rate),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=kernel_size,
                      stride=conv_stride, padding=conv_pad),
        )

        def conv_out(n):
            # Conv2d output size (dilation 1), see the nn.Conv2d shape formula.
            return (n + 2 * conv_pad - kernel_size) // conv_stride + 1

        # Derive the flattened feature size from the layer settings instead of
        # assuming "same" convolutions: conv1 + MaxPool2d(2), then conv2, conv3.
        side = conv_out(image_size) // 2
        side = conv_out(conv_out(side))
        flat_features = 64 * side * side

        head = [
            nn.Linear(in_features=flat_features, out_features=256, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=num_classes, bias=False),
        ]
        if apply_softmax:
            head.append(nn.Softmax(dim=1))
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Map a ``(batch, channels, image_size, image_size)`` tensor to class scores."""
        features = self.conv(x)
        features = features.view(features.shape[0], -1)
        return self.fc(features)

## Training and evaluating the network

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("Using device: " + device)

height, width = 512, 512  # NOTE(review): unused below; inputs are resized to image_size
num_classes   = 4

channels        = 3
kernel_size     = 3
conv_stride     = 1
conv_pad        = 1
conv_drop_rate  = 0.4
image_size      = 64

# Writer will output to ./runs/ directory by default
writer = SummaryWriter()

network = Net(channels, kernel_size, conv_stride, conv_pad, conv_drop_rate, num_classes, image_size)
print(network)
network.to(device)
LEARNING_RATE = 0.00001
# CrossEntropyLoss applies log-softmax itself, so it expects raw logits
# (not softmax probabilities) from the network.
criterion = nn.CrossEntropyLoss()

# Plain Adam without weight decay; pass weight_decay=... to add an L2 penalty.
optimizer = optim.Adam(network.parameters(), lr=LEARNING_RATE)

# Training loop:

num_epoch = 50
batch_size = 64
log_every = 10  # print the mean training loss every `log_every` batches

trainingloader, testloader, trainset, testset = createDataLoaders(batch_size, image_size)
train_acc_all = []
test_acc_all = []
try:
    for epoch in tqdm(range(num_epoch)):
        # The evaluation pass below switches to eval mode, so re-enable dropout
        # and batch-norm statistic updates at the start of every epoch.
        network.train()
        running_loss = 0.0
        batches_logged = 0
        train_correct = 0
        for i, (inputs, labels) in enumerate(trainingloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = network(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() yields a Python float, so the running sum does not keep
            # each batch's autograd graph alive.
            running_loss += loss.item()
            batches_logged += 1

            predicted = outputs.argmax(1)
            train_correct += (labels == predicted).sum().item()

            if (i + 1) % log_every == 0:
                print("[%d, %5d] loss: %.3f" %
                      (epoch + 1, i + 1, running_loss / batches_logged))
                running_loss = 0.0
                batches_logged = 0

        network.eval()
        test_correct = 0
        with torch.no_grad():
            for data, target in testloader:
                output = network(data.to(device))
                predicted = output.argmax(1).cpu()
                test_correct += (target == predicted).sum().item()

        train_acc = train_correct / len(trainset)
        test_acc = test_correct / len(testset)
        train_acc_all.append(train_acc)
        test_acc_all.append(test_acc)
        # Log accuracies as percentages without truncating to whole numbers.
        writer.add_scalar('Accuracy/train', 100 * train_acc, epoch + 1)
        writer.add_scalar('Accuracy/test', 100 * test_acc, epoch + 1)
        if epoch % 3 == 0:
            writer.flush()
        print("Accuracy train: {train:.1f}%\t test: {test:.1f}%".format(test=100*test_acc, train=100*train_acc))
except KeyboardInterrupt:
    # Stop early on Ctrl-C but still fall through to saving the model below.
    print("Training interrupted")
finally:
    writer.close()

print("training over")

# Save after normal completion as well as after an interrupt.
torch.save(network, "model.pt")
print("Saved model")