In [11]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

In [12]:
# Root folder of the fish image dataset; the Dataset class below scans it for one
# sub-folder per class. The relative "../input/..." layout presumably matches a
# Kaggle kernel — adjust when running elsewhere.
input_path = "../input/a-large-scale-fish-dataset/Fish_Dataset/Fish_Dataset"

In [48]:
import os
from PIL import Image

class Dataset(torch.utils.data.Dataset):
    """Image-classification dataset read from a directory of class folders.

    Every entry of ``input_path`` that is a directory and whose name contains
    no "." is treated as one class. Images are read from the nested
    ``<input_path>/<class_name>/<class_name>/`` folder and labelled with the
    index of their class in ``self.class_list``.

    NOTE(review): this class reuses the name ``Dataset``, so once the cell has
    run that name no longer refers to ``torch.utils.data.Dataset``. The base
    class is spelled out in full so that re-running the cell does not make the
    class inherit from its own previous definition.

    Args:
        input_path: Root directory holding one sub-directory per class.
        transform: Optional callable applied to each image before it is
            returned. ``None`` (the default) returns the PIL image unchanged.
    """

    def __init__(self, input_path=input_path, transform=None):
        # os.listdir gives no ordering guarantee; sorting keeps the
        # class-name -> label mapping identical across runs and machines.
        self.class_list = sorted(
            name
            for name in os.listdir(input_path)
            if "." not in name and os.path.isdir(os.path.join(input_path, name))
        )
        self.transform = transform
        self.classes = []  # integer label of every image, parallel to self.images
        self.images = []   # full path of every image

        for i, class_name in enumerate(self.class_list):
            # The images live one level deeper, in a folder named like its parent.
            image_dir = os.path.join(input_path, class_name, class_name)
            print(i, class_name)
            # Sorted for the same reason as above: a given index must always
            # refer to the same file.
            for file_name in sorted(os.listdir(image_dir)):
                self.images.append(os.path.join(image_dir, file_name))
                self.classes.append(i)

    def __len__(self):
        """Return the number of images that were indexed."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at position ``index``."""
        # Image.open is lazy and keeps the file open; copy the pixels while the
        # file is open so the handle is released as soon as the block exits.
        with Image.open(self.images[index]) as opened:
            im = opened.copy()

        if self.transform is not None:
            im = self.transform(im)

        return im, self.classes[index]

    def return_class(self, n):
        """Return the class name that belongs to label index ``n``."""
        return self.class_list[n]

In [49]:
# Index every image under the default `input_path`; the constructor prints one
# "<label> <class name>" line per class it finds.
train_dataset = Dataset()

In [89]:
# Pull a single sample to inspect: X is the PIL image, y its integer class label.
X, y = train_dataset[1000]

In [90]:
# Sanity check: show the sampled image and print the class name stored for its label.
plt.imshow(X)
print(train_dataset.return_class(y))

In [92]:
from torchvision import transforms

# Preprocessing / augmentation pipeline: shrink to 96x96, flip at random along
# both axes, then convert the image to a tensor.
my_transforms = transforms.Compose(
    [
        transforms.Resize((96, 96)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        # Normalisation is currently switched off; kept here for reference:
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                      std=[0.229, 0.224, 0.225]),
    ]
)

In [93]:
# Run the sample through the pipeline.
# NOTE(review): this rebinds X from the PIL image to the pipeline's output, so the
# cell is not idempotent — re-run the cell that loads X before running it again.
X = my_transforms(X)

In [94]:
# Reorder the axes (1, 2, 0) before plotting; assumes X is channel-first
# (C, H, W) after ToTensor, so this yields (H, W, C) — TODO confirm.
plt.imshow(X.permute(1,2,0))

The resized image seems to preserve the details, so let's try to build a ResNet around it.

# Resnet model
https://towardsdatascience.com/residual-network-implementing-resnet-a7da63c7b278

Create a `Conv2d` subclass that sets its padding automatically from the kernel size. `partial`, from `functools`, wraps a callable that is already in scope and pre-fills some of its arguments, giving a new callable with custom defaults.

In [96]:
from functools import partial
class Conv2dAuto(nn.Conv2d):
    """``nn.Conv2d`` whose padding is derived from its kernel size and dilation.

    The layer is built exactly like ``nn.Conv2d`` and then ``padding`` is
    overwritten with ``dilation * (kernel_size // 2)`` on each spatial axis, so
    any ``padding`` argument supplied by the caller is ignored. For an odd
    kernel size and stride 1 the output keeps the input's height and width.

    Args:
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Scale by the dilation: a dilated kernel covers a wider window, so
        # kernel_size // 2 alone would under-pad. With the default dilation of 1
        # this is the plain kernel_size // 2.
        self.padding = tuple(
            dilation * (kernel // 2)
            for kernel, dilation in zip(self.kernel_size, self.dilation)
        )
        
# 3x3 convolution factory: Conv2dAuto with kernel_size=3 and bias=False pre-filled,
# so callers only pass the remaining arguments (e.g. the channel counts).
conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)

In [99]:
# Quick check: build a 32 -> 64 channel 3x3 convolution and print its configuration.
conv = conv3x3(in_channels=32, out_channels=64)

print(conv)
# Drop the throw-away layer so it does not linger in the notebook namespace.
del conv

Use `nn.ModuleDict` to build a dictionary of activation functions. See https://towardsdatascience.com/pytorch-how-and-when-to-use-module-sequential-modulelist-and-moduledict-7a54597b5f17

In [101]:
def activation_function(activation):
    """Return a newly constructed activation module selected by name.

    Args:
        activation: One of ``'relu'``, ``'leaky_relu'``, ``'selu'`` or ``'none'``.

    Returns:
        The matching ``nn.Module``; ``'none'`` maps to ``nn.Identity()``.

    Raises:
        KeyError: If ``activation`` is not one of the names listed above.
    """
    # The lookup table is rebuilt on every call, so each caller receives its
    # own module instance.
    available = nn.ModuleDict({
        'relu': nn.ReLU(inplace=True),
        'leaky_relu': nn.LeakyReLU(negative_slope=0.01, inplace=True),
        'selu': nn.SELU(inplace=True),
        'none': nn.Identity(),
    })
    return available[activation]

In [103]:
class ResidualModule(nn.Module):
    """Generic residual block computing ``activate(blocks(x) + residual)``.

    ``blocks`` and ``shortcut`` both start out as ``nn.Identity`` (presumably
    meant to be replaced by subclasses). The residual is ``shortcut(x)`` when
    the input and output channel counts differ, and ``x`` itself otherwise.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output.
        activation: Name understood by ``activation_function``; applied to the
            sum of the two branches.
    """
    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation
        self.blocks = nn.Identity()
        self.activate = activation_function(activation)
        self.shortcut = nn.Identity()

    def forward(self, x):
        """Add the residual branch to the main branch, then apply the activation."""
        residual = self.shortcut(x) if self.should_apply_shortcut else x
        # Out-of-place add on purpose: `blocks` may return `x` itself (the
        # default Identity does), and `+=` would then overwrite the caller's
        # tensor and the residual along with it.
        x = self.blocks(x) + residual
        return self.activate(x)

    @property
    def should_apply_shortcut(self):
        """True when the channel counts differ, i.e. ``shortcut`` must be used."""
        return self.in_channels != self.out_channels
        