In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    # Print the full path of every file found anywhere under the Kaggle input mount.
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
pip install einops

In [None]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from typing import List, Union
from pathlib import Path
from torch.utils.data import random_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.optim as optim
import time
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision import transforms
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.optim import Adam
from sklearn.model_selection import train_test_split

In [None]:
# Labelled training images (one sub-folder per class), resized to the 64x64 ViT input size.
dataset = datasets.ImageFolder('/kaggle/input/iais22-birds/birds/birds/', transform=transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()]))

# Hold out 10% of the images for validation. The training size is derived from the
# validation size so the two always sum to len(dataset); the previous
# `int(0.9 * n) + 1` formula over-counted by one whenever 0.9 * n was an integer
# (e.g. n = 1000), which makes random_split raise. The seeded generator keeps the
# split identical across kernel restarts.
VAL_FRACTION = 0.1
SPLIT_SEED = 42
val_size = int(len(dataset) * VAL_FRACTION)
train_size = len(dataset) - val_size
train_set, val_set = random_split(
    dataset,
    (train_size, val_size),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


In [None]:
BATCH_SIZE = 200

# NOTE: `train` iterates over the *full* dataset, so it also contains every validation
# image. Train on `train_dataloader` whenever validation metrics must be meaningful;
# `train` is only appropriate for a final fit on all labelled data.
train = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation results do not depend on sample order, so the validation loader iterates
# in a fixed order.
val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Un-shuffled loader over every labelled image (training and validation splits together).
final_dataloader = DataLoader(dataset=dataset, batch_size=200, shuffle=False)

In [None]:
def pair(t):
    """Return ``t`` unchanged if it is already a tuple, otherwise ``(t, t)``.

    Lets image and patch sizes be given either as a single value or as an explicit
    (height, width) tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

In [None]:
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then call the wrapped module ``fn``.

    Any keyword arguments given to ``forward`` are passed through to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Attribute names define the state_dict keys (``norm.*``, ``fn.*``) and must stay
        # unchanged so previously saved checkpoints still load.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class MLP(nn.Module):
    """Token-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the output layer.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),   # net.0
            nn.GELU(),                    # net.1
            nn.Dropout(dropout),          # net.2
            nn.Linear(hidden_dim, dim),   # net.3
            nn.Dropout(dropout),          # net.4
        ]
        # A Sequential with this exact layer order keeps the ``net.<i>.*`` state_dict keys.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = x
        for layer in self.net:
            out = layer(out)
        return out

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of the input tokens (last dimension of ``x``).
        heads: number of attention heads.
        dim_head: feature size per head; q, k and v are each projected to
            ``heads * dim_head`` features.
        dropout: dropout probability on the attention weights and on the output projection.

    The output projection back to ``dim`` is replaced by an identity only when there is
    a single head whose size already equals ``dim``.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        # 1 / sqrt(dim_head) scaling of the dot products.
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # A single bias-free Linear produces q, k and v concatenated on the last axis.
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # Split the fused projection into q, k, v and bring the head axis forward:
        # (b, n, heads*dim_head) -> (b, heads, n, dim_head).
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.dropout(self.attend(scores))

        context = weights @ v
        # Merge the heads back: (b, heads, n, dim_head) -> (b, n, heads*dim_head).
        context = rearrange(context, 'b h n d -> b n (h d)')
        return self.to_out(context)

class Encoder(nn.Module):
    """Stack of ``depth`` pre-norm transformer layers with residual connections.

    Each layer computes ``x = attention(norm(x)) + x`` followed by
    ``x = mlp(norm(x)) + x``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        # Each entry is ModuleList([attention, feed_forward]); this nesting defines the
        # ``layers.<i>.0.*`` / ``layers.<i>.1.*`` state_dict keys used by saved checkpoints.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, MLP(dim, mlp_dim, dropout = dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            attention, feed_forward = layer
            x = x + attention(x)
            x = x + feed_forward(x)
        return x

class ViT(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into non-overlapping patches, each patch is flattened and
    linearly projected to ``dim`` features, a learnable class token is prepended and
    learnable positional embeddings are added. The token sequence passes through an
    ``Encoder`` and is reduced to one vector per image, either the class token
    (``pool='cls'``) or the mean over all tokens (``pool='mean'``). A LayerNorm + Linear
    head then maps that vector to ``num_classes`` logits.

    Args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of each patch; must divide the image size.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of encoder layers.
        heads: attention heads per layer.
        mlp_dim: hidden size of each encoder feed-forward block.
        pool: 'cls' or 'mean'.
        channels: number of input image channels.
        dim_head: feature size per attention head.
        dropout: dropout inside the encoder.
        emb_dropout: dropout on the token embeddings before the encoder.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_h, image_w = pair(image_size)
        patch_h, patch_w = pair(patch_size)

        # The image has to tile exactly into patches along both axes.
        assert image_h % patch_h == 0 and image_w % patch_w == 0, 'Image dimensions must be divisible by the patch size.'

        rows, cols = image_h // patch_h, image_w // patch_w
        num_patches = rows * cols
        # Number of values in one flattened patch.
        patch_dim = patch_h * patch_w * channels
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Patch embedding: (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim).
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2 = patch_w),
            nn.Linear(patch_dim, dim),
        )

        # One positional embedding per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.encoder = Encoder(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        # Identity placeholder between pooling and the classification head (no learned
        # transform, not an activation function).
        self.to_latent = nn.Identity()

        # Classification head.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        )

    def forward(self, img):
        tokens = self.to_patch_embedding(img)
        batch, n_patches, _ = tokens.shape

        # Prepend one copy of the learnable class token for every image in the batch.
        cls = repeat(self.cls_token, '1 n d -> b n d', b = batch)
        tokens = torch.cat((cls, tokens), dim = 1)
        tokens += self.pos_embedding[:, :n_patches + 1]
        tokens = self.dropout(tokens)

        tokens = self.encoder(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)   # average over all tokens, class token included
        else:
            pooled = tokens[:, 0]           # final representation of the class token

        return self.mlp_head(self.to_latent(pooled))

In [None]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# ViT for 64x64 bird images split into 8x8 patches (8 * 8 = 64 patches per image).
# The number of output logits comes from the ImageFolder class list instead of a
# hard-coded 400, so the head always matches the labels CrossEntropyLoss receives.
model = ViT(
    image_size = 64,      # must match the Resize((64, 64)) used when loading the images
    patch_size = 8,
    num_classes = len(dataset.classes),
    dim = 256,
    depth = 2,
    heads = 12,           # dim_head keeps its default of 64 -> 12 * 64 = 768 attention features
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)

In [None]:
learning_rate = 1e-3

# Move the model to the selected device *before* building the optimizer. `device` was
# computed in an earlier cell but never used, so training previously always ran on CPU.
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=learning_rate, amsgrad=False)


def train_loop(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` batches of images and class labels.
        model: network to train; put in training mode so dropout is active.
        loss_fn: called as ``loss_fn(pred, y)``.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved; defaults to the device holding the model's
            parameters, so data and model always match.

    Prints the current loss every 100 batches.
    """
    if device is None:
        device = next(model.parameters()).device
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader``; print and return accuracy and mean batch loss.

    The model is switched to eval mode so dropout is disabled while measuring. It used
    to stay in training mode, which made the reported metrics noisy.

    Args:
        dataloader: yields ``(X, y)`` batches of images and class labels.
        model: network to evaluate.
        loss_fn: called as ``loss_fn(pred, y)``.
        device: where each batch is moved; defaults to the model's parameter device.

    Returns:
        ``(accuracy, avg_loss)``, with accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("test_loop received an empty dataloader")
    test_loss, correct = 0, 0

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct, test_loss


epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # Train on the training split only. The `train` loader covers the full dataset,
    # including every validation image, so validating after training on it inflates
    # the reported accuracy.
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(val_dataloader, model, loss_fn)
print("Done!")

In [None]:
# Save only the model's state_dict (parameters and buffers), not the module object itself.
PATH = './transformer-15epo-12h.pth'
torch.save(obj=model.state_dict(), f=PATH)

In [None]:
# Rebuild the network with the training hyperparameters and load the saved weights.
red = ViT(
    image_size = 64,
    patch_size = 8,
    num_classes = len(dataset.classes),   # same head size as the trained model
    dim = 256,
    depth = 2,
    heads = 12,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)
# A freshly constructed module is in training mode; without eval() the dropout layers
# (dropout=0.1 / emb_dropout=0.1) stay active and make every prediction random.
red.eval()
# map_location="cpu": the checkpoint may hold GPU tensors; loading onto the CPU works
# whether or not a GPU is available in this session.
red.load_state_dict(torch.load(PATH, map_location="cpu"))

In [None]:
from PIL import Image
class BirdsDatasetTest(torch.utils.data.Dataset):
    """Unlabelled image dataset for the submission test set.

    Collects every ``*.jpg`` file exactly one directory level below ``path`` and yields
    ``(image_tensor, image_id)`` pairs, where ``image_id`` is the file name up to its
    first dot.

    Args:
        path: root directory of the test images (``str`` or ``Path``).
        transform: applied to each RGB PIL image; by default it resizes to 64x64 and
            converts to a tensor, the same preprocessing used for training.
    """

    def __init__(self, path: Union[Path, str],
                transform: Union['Transform', List['Transform']] = transforms.Compose([transforms.Resize((64,64)),transforms.ToTensor()])):
        self.path = Path(path)
        # Always glob on self.path: the old code called path.glob on the raw argument
        # and crashed when a str was passed.
        # NOTE(review): the pattern 'x' only matches an entry literally named "x";
        # '*' was probably intended. The attribute is not used anywhere in this notebook.
        self.labels = [p.name for p in self.path.glob('x')]
        # Sorted so the image order (and therefore the CSV row order) is deterministic.
        self.images = sorted(self.path.glob('*/*.jpg'))
        self.transform = transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int):
        image_path = self.images[index]
        # Load the way torchvision's ImageFolder does. convert('RGB') guarantees 3 channels,
        # since grayscale/RGBA/palette files would otherwise break the patch embedding.
        # The context manager closes the file handle once the pixels have been read.
        with Image.open(image_path) as img:
            image = self.transform(img.convert('RGB'))
        image_id = image_path.name.split(sep=".")[0]
        return image, image_id

In [None]:
# Class names in ImageFolder order: list index i is the label index the model predicts.
clases = dataset.classes
# Unlabelled submission images, served one per batch (the CSV cell reads id[0]).
test_dat = BirdsDatasetTest(path=Path('/kaggle/input/iais22-birds/submission_test/'))
dataloader = DataLoader(test_dat, batch_size=1, shuffle=False)

In [None]:
import csv
import os

# Inference must run with dropout disabled. `red` starts in training mode unless eval()
# has been called, and no_grad avoids building an autograd graph for every batch.
red.eval()
# newline='' is required by the csv module for files it writes (prevents blank rows on Windows).
with open('submission.csv', 'w', newline='') as csvfile, torch.no_grad():
    fieldnames = ['Id', 'Category']
    writer = csv.DictWriter(csvfile, fieldnames = fieldnames)
    writer.writeheader()
    for images, image_ids in dataloader:
        predicted = red(images).argmax(dim=1).tolist()
        # Iterate per sample so any loader batch size works (currently batch_size=1).
        for image_id, class_index in zip(image_ids, predicted):
            writer.writerow({'Id': int(image_id), 'Category': clases[class_index]})