In [None]:
import argparse
import os
import random
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets
from tqdm import tqdm

from train import train_model
from model import *
from data import *

## Parameters

In [None]:
# Paths and hyper-parameters shared by the cells below.
data_dir = 'bird_dataset'
batch_size = 32
epochs = 20
# NOTE(review): `lr` is not referenced by any cell in this notebook -- the fine-tuning
# and K-fold cells hardcode their own learning rates. Reconcile before relying on it.
lr = 0.005
momentum = 0.9
seed = 42
save_dir = 'experiment'

use_cuda = torch.cuda.is_available()

# Seed every RNG that can influence a run: Python's `random`, NumPy's global generator
# (scikit-learn falls back to it when no random_state is passed) and torch.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

## Data

In [None]:
# Standard ImageNet per-channel mean / std, applied as the last step of both pipelines.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# Training: resize + centre crop to 224x224, then light random augmentation
# (up to +/-10 degree rotation and a horizontal flip) before tensor conversion.
train_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomRotation(10),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean_nums, std_nums),
    ]
)

# Validation: the same deterministic resize/crop/normalise, without augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean_nums, std_nums),
    ]
)

In [None]:
# One ImageFolder per split; each expects one sub-directory per class.
my_datasets = {
    'train': datasets.ImageFolder(data_dir + '/train_images/', transform=train_transforms),
    'val': datasets.ImageFolder(data_dir + '/val_images/', transform=val_transforms),
}

# Same loader settings for both splits; only the training split is shuffled.
dataloaders = {
    split: torch.utils.data.DataLoader(
        my_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=4,
    )
    for split in ('train', 'val')
}

## Model

In [None]:
# Pretrained ResNeXt-101 32x8d from the facebookresearch/WSL-Images torch.hub repository
# (fetched by torch.hub on first use).
model = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')

# Freeze every top-level child except the last residual stage ('layer4'). The original
# `fc` is frozen as well, but it is replaced just below by a fresh, trainable head.
for name, child in model.named_children():
    if name != 'layer4':
        for param in child.parameters():
            param.requires_grad = False

# Size the new head from the training folder instead of hardcoding the class count,
# so it always matches the labels ImageFolder produces.
n_features = model.fc.in_features
n_classes = len(my_datasets['train'].classes)
model.fc = nn.Linear(n_features, n_classes)

# Honour `use_cuda` everywhere (the loss was previously moved with .cuda() unconditionally).
device = torch.device('cuda' if use_cuda else 'cpu')
print('Using GPU' if use_cuda else 'Using CPU')
model.to(device)

# NOTE(review): the parameters cell defines lr = 0.005, but this cell has always trained
# with 0.001. The value is kept to preserve existing results -- reconcile the two.
finetune_lr = 0.001
optimizer = optim.SGD(
    model.parameters(),
    lr=finetune_lr,
    momentum=momentum,
    weight_decay=0.00005
)

criterion = nn.CrossEntropyLoss().to(device)
# step_size=1, gamma=0.9: the learning rate is multiplied by 0.9 on every
# lr_scheduler.step() call (presumably once per epoch inside train_model -- confirm).
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

## Train

In [None]:
# Fine-tune for `epochs` epochs. train_model returns the trained model plus a second
# value, which this cell discards.
model, _ = train_model(model, dataloaders, criterion, optimizer, lr_scheduler, epochs)

In [None]:
# Persist only the weights (state_dict). Create the output directory first: on a fresh
# checkout it does not exist and torch.save would raise FileNotFoundError.
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_dir, 'resnext.pth'))

## K-fold cross-validation (one model per fold)

In [None]:
from sklearn.model_selection import KFold

In [None]:
# Cross-validation: train a fresh model on each fold of the cropped images and save
# each fold's weights separately.
n_splits = 8
n_folds_to_train = 5   # only the first 5 of the 8 folds are trained
fold_epochs = 7
fold_lr = 0.045

# A fixed random_state makes fold membership reproducible across runs.
kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

# Two views of the same folder: augmented transforms for the training subset and
# deterministic transforms for the held-out subset. Previously the held-out images also
# went through random rotation/flip, which made validation metrics noisy. ImageFolder
# lists files in a deterministic (sorted) order, so an index refers to the same image
# in both views.
images_dir = os.path.join(data_dir, 'images_crop')
dataset = datasets.ImageFolder(images_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(images_dir, transform=val_transforms)
assert len(dataset) == len(val_dataset), 'train/val views of images_crop disagree'

device = torch.device('cuda' if use_cuda else 'cpu')
os.makedirs(save_dir, exist_ok=True)

for i, (train_index, test_index) in enumerate(kf.split(dataset)):
    if i >= n_folds_to_train:
        break

    train_subset = torch.utils.data.Subset(dataset, train_index)
    val_subset = torch.utils.data.Subset(val_dataset, test_index)

    dataloaders = {
        'train': torch.utils.data.DataLoader(
            train_subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4
        ),
        'val': torch.utils.data.DataLoader(
            val_subset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4
        )
    }

    print('Fold : {}'.format(i + 1))

    # Fresh model / optimizer / scheduler per fold so folds share no training state.
    model = Resnext101WSL(last_conv=True)
    model.to(device)

    optimizer = optim.SGD(
        model.parameters(),
        lr=fold_lr,
        momentum=momentum,
        weight_decay=0.00005
    )

    criterion = torch.nn.CrossEntropyLoss().to(device)
    # step_size=1, gamma=0.6: learning rate multiplied by 0.6 on every scheduler step.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.6)

    model, _ = train_model(
        model,
        dataloaders,
        criterion,
        optimizer,
        lr_scheduler,
        fold_epochs,
        reduce_lr_on_plateau=False
    )

    # File suffix is the 0-based fold index (the printed fold number above is 1-based).
    torch.save(model.state_dict(), os.path.join(save_dir, 'resnextWSL{}.pth'.format(i)))