<a href="https://colab.research.google.com/github/brandleyzhou/summerschool_AF/blob/main/basic_toturial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside Colab and make the course folder the working
# directory, so relative paths used later in the notebook (e.g. 'Archive')
# resolve inside it.
from google.colab import drive
mount='/content/gdrive'
print("Colab: mounting Google drive on ", mount)
drive.mount(mount)
import os
drive_root = mount + "/My Drive/summerschool"
# IPython magic (not plain Python): change the notebook's working directory.
%cd $drive_root


In [None]:
import torch
import torch.utils.data as data
from torchvision import transforms


import os
import random
import numpy as np
import copy
from PIL import Image  # using pillow-simd for increased speed

def pil_loader(path):
    """Open the image file at ``path`` and return an RGB copy of it.

    Both the file handle and the decoded image are closed on exit; the
    returned image is the separate object produced by ``convert('RGB')``.
    """
    with open(path, 'rb') as fh, Image.open(fh) as raw:
        rgb = raw.convert('RGB')
    return rgb

class FishDataset(data.Dataset):
    """Fish image-classification dataset.

    Each entry of ``filenames`` is a whitespace-separated line
    ``"<class_folder> <image_name>"``. The image is loaded from
    ``data_path/<class_folder>/<image_name>``, and its label is the index of
    ``<class_folder>`` in ``CLASSES``.

    When ``is_train`` is True, each sample is independently flipped
    horizontally and colour-jittered, each with probability ~0.5.
    """

    # Label order: the position in this tuple is the integer class id.
    CLASSES = ("saithe", "herring", "grey_gurnard", "norway_pout", "anchovy",
               "red_mullet", "cod", "haddock", "sardine", "mackerel")

    def __init__(self, data_path, filenames, is_train=False):
        super(FishDataset, self).__init__()
        self.data_path = data_path
        self.filenames = filenames
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # and exists in both old and new Pillow releases.
        self.interp = Image.LANCZOS
        self.is_train = is_train
        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # Probe whether this torchvision accepts (min, max) jitter ranges;
        # fall back to scalar strengths if it raises TypeError.
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        # Built once and reused: a ColorJitter transform samples fresh factors
        # on every call. ColorJitter.get_params only returns sampled parameters
        # (a tuple in current torchvision), so it cannot be applied to an image.
        self.color_jitter = transforms.ColorJitter(
            self.brightness, self.contrast, self.saturation, self.hue)

    def preprocess(self, img, color_aug):
        """Apply ``color_aug`` to the PIL image ``img`` and convert it to a tensor."""
        return self.to_tensor(color_aug(img))

    def generate_gt(self, folder):
        """Return the integer class id for the class folder name ``folder``.

        Raises:
            ValueError: if ``folder`` is not one of ``CLASSES``.
        """
        return self.CLASSES.index(folder)

    def __getitem__(self, index):
        """Return ``(image_tensor, class_id)`` for sample ``index``."""
        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]
        img_name = line[1]
        gt = self.generate_gt(folder)
        img = self.loader(os.path.join(self.data_path, folder, img_name))

        if do_flip:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if do_color_aug:
            color_aug = self.color_jitter
        else:
            color_aug = (lambda x: x)
        return self.preprocess(img, color_aug), gt

    def __len__(self):
        return len(self.filenames)


In [None]:
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
#from .dataset import FishDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
# Training configuration.
data_path = 'Archive'                              # one sub-folder per class
train_set_path = data_path + '/train_val.txt'      # "<folder> <image_name>" per line
num_epochs, batch_size = 5, 16
learning_rate = 1e-4
def readlines(filename):
    """Return the lines of the text file ``filename`` without line terminators."""
    with open(filename, 'r') as handle:
        contents = handle.read()
    return contents.splitlines()
#============================= data loader =======================
# Each line of the split file is "<class_folder> <image_name>".
train_files = readlines(train_set_path)
print(len(train_files))
trainset = FishDataset(data_path, train_files, is_train=True)
train_loader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    drop_last=True,  # every yielded batch has exactly batch_size samples
)
#=================================================================

#============================= network ===========================
class MyModel(nn.Module):
    """Small CNN classifier producing scores for 10 fish classes.

    The flattened feature size (64 * 32 * 32) fixes the input resolution:
    a 3x512x512 image becomes 256 -> 128 -> 63 -> 32 pixels per side through
    conv1, maxpool, conv2, maxpool. Other sizes fail at ``self.d1``.
    NOTE(review): confirm the dataset images are all 512x512.
    """

    def __init__(self):
        super(MyModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2)
        self.d1 = nn.Linear(64 * 32 * 32, 128)
        self.d2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalised class scores (logits) of shape (batch, 10).

        No softmax is applied here. ``nn.CrossEntropyLoss`` already applies
        log-softmax to its input, and feeding it probabilities (a double
        softmax) compresses the loss and weakens the gradients. Use
        ``F.softmax(out, dim=1)`` when probabilities are needed; argmax-based
        accuracy is unaffected.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        # NOTE(review): unlike conv1, conv2 has no ReLU before pooling;
        # confirm this is intentional.
        x = self.maxpool(x)
        x = x.flatten(start_dim=1)

        x = self.d1(x)
        x = F.relu(x)

        return self.d2(x)
#=====================================================================
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel().to(device)
# CrossEntropyLoss expects raw logits and applies log-softmax internally.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#=====================================================================

def get_accuracy(logit, target, batch_size):
    """Return the percentage of rows of ``logit`` whose argmax equals ``target``.

    ``batch_size`` is used as the denominator, so it should equal the number
    of rows in ``logit``.
    """
    _, predicted = torch.max(logit, dim=1)
    num_correct = (predicted.view_as(target) == target).sum()
    return (100.0 * num_correct / batch_size).item()

# Move batches to whatever device the model's parameters live on, so this
# loop works on both CPU and GPU.
device = next(model.parameters()).device
for epoch in range(num_epochs):
    train_running_loss = 0.0
    train_acc = 0.0
    num_batches = 0
    model.train()
    for img, label in train_loader:
        img = img.to(device)
        label = label.to(device)
        out = model(img)
        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()
        # Use the real batch length so the metric stays correct even if the
        # loader yields a short final batch (e.g. drop_last=False).
        train_acc += get_accuracy(out, label, label.size(0))
        num_batches += 1

    model.eval()
    # Average over the number of batches actually processed (not the last
    # loop index); guard against an empty loader.
    num_batches = max(num_batches, 1)
    print('Epoch: %d | Loss: %.4f | Train Accuracy: %.2f'
          % (epoch, train_running_loss / num_batches, train_acc / num_batches))

1886
Epoch: 0 | Loss: 2.0588 | Train Accuracy: 43.97
Epoch: 1 | Loss: 1.8797 | Train Accuracy: 60.83
Epoch: 2 | Loss: 1.8367 | Train Accuracy: 65.46
Epoch: 3 | Loss: 1.8208 | Train Accuracy: 67.03
Epoch: 4 | Loss: 1.8089 | Train Accuracy: 67.83
