In [1]:
!pip install -U wandb -q
!pip install -U albumentations -q
!pip install -U transformers datasets -q

In [2]:
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import zipfile
import os
import wandb
import copy
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations.core.transforms_interface import ImageOnlyTransform
from PIL import Image, ImageOps, ImageEnhance
from torchvision import datasets
from tqdm import tqdm
import cv2

In [3]:
!nvidia-smi

# DATA

In [13]:
# ImageNet channel statistics, shared by the evaluation (torchvision) and
# training (albumentations) pipelines so both normalize inputs identically.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

# Side length (pixels) of the square images fed to the network.
size = 224

# Deterministic pipeline for validation / test images (PIL input):
# resize to size x size, convert to a float tensor, then normalize.
data_transforms = transforms.Compose(
    [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(imagenet_mean), std=list(imagenet_std)),
    ]
)

# Stochastic augmentation pipeline for training images (numpy array input,
# as produced by AugmentedDS). Transforms run in list order.
train_transforms = A.Compose(
    [
        # Occlusion, noise and geometric distortions.
        A.CoarseDropout(p=0.1),
        A.GaussNoise(p=0.2),
        A.ElasticTransform(p=0.33),
        A.Rotate(),
        A.ShiftScaleRotate(p=0.2),
        # Colour / lighting jitter.
        A.RGBShift(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        # Occasionally shrink the image first so the final resize has to
        # upsample from a low-resolution version (randomly downsample).
        A.Resize(32, 32, p=0.1),
        A.Resize(64, 64, p=0.2),
        # Always bring the image to the network input size.
        A.Resize(size, size, interpolation=cv2.INTER_LANCZOS4),
        # NOTE(review): scale_min / scale_max / interpolation are the older
        # Downscale arguments; newer albumentations releases expose
        # scale_range / interpolation_pair instead — verify against the
        # installed version.
        A.Downscale(scale_min=0.1, scale_max=0.25, interpolation=cv2.INTER_LANCZOS4, p=0.3),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
)

In [14]:
class AugmentedDS(torch.utils.data.Dataset):
    """Wrap a dataset and apply an albumentations-style transform to its images.

    Each item of the wrapped dataset is an ``(image, label)`` pair. When a
    transform is set, the image is converted with ``np.array`` and passed as
    ``transform(image=...)``; the ``"image"`` entry of the returned mapping
    replaces it. Labels are passed through unchanged, and without a transform
    items are returned exactly as the wrapped dataset produced them.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        if self.transform is None:
            return sample, target
        augmented = self.transform(image=np.array(sample))
        return augmented["image"], target
    
def visualize_augmentations(dataset, idx=0, samples=10, cols=5):
    """Plot several random augmentations of a single dataset item.

    The dataset is deep-copied so the caller's transform is left untouched;
    in the copy, ``A.Normalize`` and ``ToTensorV2`` are removed from the
    pipeline so the augmented images remain displayable arrays.

    Args:
        dataset: dataset exposing a ``transform`` attribute that holds an
            ``A.Compose`` pipeline (e.g. ``AugmentedDS``).
        idx: index of the item to augment repeatedly.
        samples: number of augmented draws to show.
        cols: number of subplot columns.
    """
    dataset = copy.deepcopy(dataset)
    dataset.transform = A.Compose(
        [t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))]
    )
    # Round up so every sample gets a panel even when samples % cols != 0;
    # keep at least one row so plt.subplots never receives nrows=0.
    rows = max(1, -(-samples // cols))
    # squeeze=False keeps `ax` a 2-D array even for a 1x1 grid, so ravel() works.
    figure, ax = plt.subplots(
        nrows=rows, ncols=cols, figsize=(12, 3 * rows), squeeze=False
    )
    axes = ax.ravel()
    for i in range(samples):
        image, _ = dataset[idx]
        axes[i].imshow(image)
    # Hide ticks on every panel, including unused trailing ones.
    for axis in axes:
        axis.set_axis_off()
    plt.tight_layout()
    plt.show()

# MODEL

In [15]:
from transformers import BeitForImageClassification, ViTForImageClassification, DeiTForImageClassification
# Number of target classes in the bird dataset.
nclasses = 20


class BEiTNet(torch.nn.Module):
    """Pretrained image-classification transformer with a small MLP head.

    The backbone's ``classifier`` Linear layer is replaced by
    ``Linear(in_features, hidden_dim) -> Dropout(dropout) -> ReLU ->
    Linear(hidden_dim, num_classes)``, and ``forward`` returns raw logits.

    Args:
        pretrained: Hugging Face checkpoint name passed to ``from_pretrained``.
        num_classes: number of output logits; ``None`` means use the
            module-level ``nclasses`` at construction time.
        hidden_dim: width of the hidden layer in the new head.
        dropout: dropout probability in the new head.
        model_cls: ``*ForImageClassification`` class used to load
            ``pretrained``. It must expose a ``classifier`` Linear layer, as the
            BEiT default does. Alternatives previously tried in this notebook:
            ``ViTForImageClassification`` with ``'google/vit-large-patch16-224'``
            and ``DeiTForImageClassification`` with
            ``'facebook/deit-base-distilled-patch16-224'``
            (NOTE(review): confirm their ``classifier`` attribute before use).
    """

    def __init__(
        self,
        pretrained='microsoft/beit-large-patch16-224',
        num_classes=None,
        hidden_dim=512,
        dropout=0.4,
        model_cls=BeitForImageClassification,
    ):
        super().__init__()
        if num_classes is None:
            num_classes = nclasses
        # Attribute name kept as `beit` so existing state_dict keys still load.
        self.beit = model_cls.from_pretrained(pretrained)
        self.beit.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.beit.classifier.in_features, hidden_dim),
            torch.nn.Dropout(p=dropout),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return the classification logits for a batch of image tensors ``x``."""
        return self.beit(x).logits

# Training

In [16]:
torch.cuda.empty_cache()

# Training settings
config = {
    "data": '../input/bird-dataset-cropped/bird_dataset',
    "batch_size": 10,
    "epochs": 150,
    "lr": 1e-2,
    "momentum": 0.9,
    "seed": 9823,
    "weight_decay": 3e-5,
    "experiment": 'experiment',
    "checkpoint": None,
    "clipping": 1.
}

# Seed the torch and numpy RNGs so the "seed" logged to wandb actually applies.
# NOTE(review): depending on the installed albumentations version, augmentations
# may also draw from Python's `random` module or from a generator configured via
# `A.Compose(seed=...)`; seed those as well if fully reproducible augmentation
# is required.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create experiment folder (no-op if it already exists)
os.makedirs(config["experiment"], exist_ok=True)

# Training images stay as PIL images here; AugmentedDS applies the
# albumentations pipeline. Validation images use the torchvision pipeline.
train_dataset = datasets.ImageFolder(config["data"] + '/train_images', transform=None)
val_dataset = datasets.ImageFolder(config["data"] + '/val_images', transform=data_transforms)

train_dataset = AugmentedDS(train_dataset, transform=train_transforms)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=1,
    pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=1,
    pin_memory=True)

# Sanity check: show random augmentations of training item 20.
visualize_augmentations(train_dataset, 20)

# Loss function + weighted classes: every class has weight 1.0 except
# classes 16 (0.7) and 13 (0.8).
# NOTE(review): the rationale for these classes/values is not recorded here —
# presumably they are over-represented or easy classes; confirm.
weights = torch.ones(nclasses)
weights[16] = 0.7
weights[13] = 0.8
criterion = torch.nn.CrossEntropyLoss(weight=weights.to(device), reduction='mean')

# Neural network, optionally initialised from a saved state_dict.
model = BEiTNet()
if config["checkpoint"] is not None:
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(config["checkpoint"], map_location=device)
    model.load_state_dict(checkpoint)
    del checkpoint

# Move the model before constructing the optimizer (PyTorch recommends
# building optimizers from parameters that already live on the target device).
model.to(device)

# Alternatives tried previously: Adam with the same lr/weight_decay, and
# CosineAnnealingLR(optimizer, 150).
optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"], weight_decay=config["weight_decay"])
# Halve the learning rate at epochs 10, 20 and 30.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], 0.5)
# Loss scaling is only meaningful with CUDA autocast; disable it explicitly on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

In [None]:
def pil_loader(path):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    The file handle and the lazily-decoded source image are both closed
    before returning; ``convert`` produces a fully loaded copy.
    """
    with open(path, 'rb') as handle, Image.open(handle) as source:
        return source.convert('RGB')

def eval(outfile):
    """Predict a class for every test image and write the results as CSV.

    Reads the ``.jpg`` files in ``<config["data"]>/test_images/mistery_category``
    (in sorted order), runs each through ``data_transforms`` and the global
    ``model`` on ``device``, and writes an ``Id,Category`` header followed by
    one ``<file name without extension>,<predicted class index>`` row per image.

    Note: the name shadows Python's built-in ``eval``; it is kept because the
    training loop calls it under this name.

    Args:
        outfile: path of the CSV file to create (overwritten if present).
    """
    test_dir = os.path.join(config["data"], 'test_images', 'mistery_category')
    # Inference only: put the model in eval mode regardless of the state the
    # caller left it in, and skip autograd bookkeeping under no_grad.
    model.eval()
    with open(outfile, "w") as output_file, torch.no_grad():
        output_file.write("Id,Category\n")
        for f in sorted(os.listdir(test_dir)):
            stem, ext = os.path.splitext(f)
            if ext != '.jpg':
                continue
            data = data_transforms(pil_loader(os.path.join(test_dir, f)))
            # Add the batch dimension expected by the model.
            data = data.unsqueeze(0).to(device)
            pred = model(data).argmax(dim=1).item()
            output_file.write("%s,%d\n" % (stem, pred))

def train(epoch):
    """Run one mixed-precision optimisation pass over ``train_loader``.

    Uses the global ``model``, ``optimizer``, ``criterion`` and GradScaler
    ``scaler``. Gradients are unscaled and clipped to ``config["clipping"]``
    before every optimizer step. The LR scheduler is not stepped here.

    Args:
        epoch: current epoch index (not used by the body).

    Returns:
        The mean of the per-batch training losses.
    """
    model.train()
    running_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass and loss under CUDA autocast (mixed precision).
        with torch.cuda.amp.autocast():
            logits = model(images)
            batch_loss = criterion(logits, labels)
        # Backprop the scaled loss, then unscale so clipping sees true norms.
        scaler.scale(batch_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config["clipping"])
        scaler.step(optimizer)
        scaler.update()
        running_loss += batch_loss.item()
        del images, labels
    return running_loss / len(train_loader)

def validation():
    """Evaluate the global ``model`` on ``val_loader`` without gradients.

    Returns:
        ``(validation_loss, accuracy)`` where ``validation_loss`` is the mean
        of the per-batch ``criterion`` values (the same normalisation ``train``
        uses, so the two losses are on the same scale) and ``accuracy`` is the
        percentage of correctly classified validation samples.
    """
    model.eval()
    validation_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Each criterion value is already averaged over its batch.
            validation_loss += criterion(output, target).item()
            # Predicted class = index of the largest logit.
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()

    # Normalise by the number of batches (not samples): every accumulated
    # term is already a batch mean.
    validation_loss /= len(val_loader)
    accuracy = 100. * correct / len(val_loader.dataset)
    return validation_loss, accuracy

run = wandb.init(
    project="recvis-tp3",
    config=config
)

# Only epochs whose validation accuracy (in %) exceeds this threshold write a
# submission CSV and are eligible to become the saved best checkpoint.
acc_threshold = 90

try:
    with tqdm(range(config["epochs"])) as t:
        acc_best = 0
        loss_best = float('inf')
        for epoch in t:
            # LR actually used for this epoch; read it before the scheduler steps.
            lr = scheduler.get_last_lr()[0]
            loss_train = train(epoch)  # Train
            loss_val, acc_val = validation()  # Validate
            scheduler.step()
            if acc_val > acc_threshold:
                # Accuracy is truncated (not rounded) to 2 decimals in the file name.
                eval(os.path.join(config["experiment"], f'model_{epoch}_{int(100. * acc_val) / 100.}.csv'))
            # Better = higher accuracy, or equal accuracy with lower validation loss.
            is_better = (acc_val > acc_best) or (acc_val == acc_best and loss_val < loss_best)
            if is_better and acc_val > acc_threshold:
                torch.save(model.state_dict(), os.path.join(config["experiment"], 'best.pth'))
                acc_best = acc_val
                loss_best = loss_val
            t.set_postfix(loss_train=loss_train, loss_val=loss_val, acc_val=acc_val)
            wandb.log({
                "loss_train": loss_train,
                "loss_test": loss_val,
                "acc_test": acc_val,
                "lr": lr,
            })
finally:
    # Close the wandb run even if training is interrupted.
    run.finish()