In [None]:
import json
import os

# Kaggle API credentials. They are picked up from the environment when the
# variables are set, so real secrets do not have to be pasted into (and saved
# with) the notebook; otherwise the blank placeholders must be filled in by hand.
api_token = {
    "username": os.environ.get("KAGGLE_USERNAME", ""),
    "key": os.environ.get("KAGGLE_KEY", ""),
}

# Resolve the config directory from the current user's home instead of
# hard-coding /root, and create it idempotently so the cell can be re-run.
kaggle_dir = os.path.expanduser("~/.kaggle")
os.makedirs(kaggle_dir, exist_ok=True)
kaggle_json = os.path.join(kaggle_dir, "kaggle.json")

with open(kaggle_json, "w") as file:
    json.dump(api_token, file)

# Restrict the token file to its owner (read/write only).
os.chmod(kaggle_json, 0o600)

In [None]:
# Fetch the competition files with the Kaggle CLI; the unzip step further down
# expects the resulting vk-made-sports-image-classification.zip archive.
! kaggle competitions download -c vk-made-sports-image-classification

In [None]:
! mkdir data

In [None]:
! unzip vk-made-sports-image-classification.zip -d data

In [None]:
! pip install -q transformers datasets

In [None]:
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from PIL import Image

import time

import torch
from torch import nn
# from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchvision.transforms import (ToTensor, Normalize, Compose, Resize, CenterCrop,
    RandomResizedCrop, RandomRotation, RandomHorizontalFlip, RandomAutocontrast, ToPILImage)
from torchvision import models
from torchvision.io import read_image, ImageReadMode
from torch.optim import Adam
from torchvision.transforms import Resize
from torch.nn.functional import cross_entropy, relu
from sklearn import preprocessing
import os

from datasets import load_dataset, Image, Dataset

In [None]:
# Directory the competition archive was unpacked into; the CSV index files are
# read from here.
data_path = "data/"

In [None]:
# Index tables shipped with the competition: train.csv provides the image_id and
# label columns used below, test.csv provides image_id.
df_train = pd.read_csv(f"{data_path}train.csv")
df_test = pd.read_csv(f"{data_path}test.csv")

In [None]:
# Hold out 20% of the labelled rows for validation. The fixed seed makes the
# split (and therefore the validation metric used for model selection)
# reproducible across kernel restarts.
train = df_train.sample(frac=0.8, random_state=42)
# Everything that was not sampled into `train` becomes the validation set.
val = df_train.drop(train.index)

In [None]:
from datasets import ClassLabel

# One label vocabulary shared by every split, built from ALL labelled rows.
# Encoding each split on its own derives the ids from the classes present in
# that split only, so a class missing from one split would shift its ids
# relative to the other split.
label_feature = ClassLabel(names=sorted(df_train.label.astype(str).unique()))


def _build_image_dataset(image_dir, image_ids, labels=None):
    """Create a `Dataset` of lazily decoded images, optionally with labels.

    Args:
        image_dir: path prefix (with trailing slash) prepended to every image id.
        image_ids: iterable of image file names.
        labels: optional iterable of class names aligned with `image_ids`;
            omit it for unlabeled data.

    Returns:
        Dataset with an "image" column and, when `labels` is given, a "label"
        column encoded with the shared `label_feature` vocabulary.
    """
    columns = {"image": [image_dir + image_id for image_id in image_ids]}
    if labels is not None:
        columns["label"] = [str(label) for label in labels]

    dataset = Dataset.from_dict(columns)
    if labels is not None:
        # String class names -> integer ids using the shared vocabulary.
        dataset = dataset.cast_column("label", label_feature)
    return dataset.cast_column("image", Image())


dataset_train = _build_image_dataset('data/train/', train.image_id, train.label)
dataset_val = _build_image_dataset('data/train/', val.image_id, val.label)
dataset_test = _build_image_dataset('data/test/', df_test.image_id)


In [None]:
# Lookup tables between integer class ids and class names, taken from the
# encoded "label" feature of the training set.
id2label = dict(enumerate(dataset_train.features['label'].names))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
from transformers import ViTImageProcessor

# Image processor of the pretrained checkpoint; its mean/std and size settings
# drive the torchvision transforms defined below.
processor = ViTImageProcessor.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k"
)

In [None]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

# Normalisation statistics and input resolution come from the processor.
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]

normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: random crop and horizontal flip as augmentation, then
# tensor conversion and normalisation.
_train_transforms = Compose([
    RandomResizedCrop(size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize and centre crop, then the same
# tensor conversion and normalisation.
_val_transforms = Compose([
    Resize(size),
    CenterCrop(size),
    ToTensor(),
    normalize,
])

def train_transforms(examples):
    """Add training-time tensors to a batch under the 'pixel_values' key.

    Every entry of ``examples['image']`` is converted to RGB and run through
    ``_train_transforms``. The batch dict is modified in place and returned.
    """
    augmented = []
    for image in examples['image']:
        augmented.append(_train_transforms(image.convert("RGB")))
    examples['pixel_values'] = augmented
    return examples

def val_transforms(examples):
    """Add evaluation-time tensors to a batch under the 'pixel_values' key.

    Every entry of ``examples['image']`` is converted to RGB and run through
    ``_val_transforms``. The batch dict is modified in place and returned.
    """
    rgb_images = (image.convert("RGB") for image in examples['image'])
    examples['pixel_values'] = list(map(_val_transforms, rgb_images))
    return examples

In [None]:
# Register the on-the-fly transforms: augmentation for the training split, the
# deterministic pipeline for validation and test.
for _split, _transform in (
    (dataset_train, train_transforms),
    (dataset_val, val_transforms),
    (dataset_test, val_transforms),
):
    _split.set_transform(_transform)

In [None]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    """Combine a list of transformed examples into one batch dict.

    Args:
        examples: list of dicts, each with a "pixel_values" tensor and, for
            labelled data, a "label" entry.

    Returns:
        dict with "pixel_values" (the example tensors stacked along a new
        leading dimension) and, only when the examples carry labels, "labels".
    """
    batch = {
        "pixel_values": torch.stack([example["pixel_values"] for example in examples])
    }
    # The unlabeled test split has no "label" key; leave "labels" out for it
    # instead of raising KeyError. A batch that mixes labelled and unlabeled
    # examples still fails loudly in the list comprehension below.
    if "label" in examples[0]:
        batch["labels"] = torch.tensor([example["label"] for example in examples])
    return batch

# Small loader over the training split, used below to inspect one batch.
train_dataloader = DataLoader(
    dataset=dataset_train,
    batch_size=4,
    collate_fn=collate_fn,
)

In [None]:
# Sanity check: pull one batch and print the shape of every tensor in it.
batch = next(iter(train_dataloader))
for name, value in batch.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(name, value.shape)

In [None]:
from transformers import ViTForImageClassification

# Load the pretrained ViT checkpoint and pass it the id <-> name mappings of
# this dataset's classes.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    label2id=label2id,
    id2label=id2label,
)

In [None]:
from transformers import TrainingArguments, Trainer

# Key returned by compute_metrics that is used to pick the best checkpoint.
metric_name = "f1"

args = TrainingArguments(
    # Where checkpoints and TensorBoard logs are written.
    output_dir='results',
    logging_dir='logs',
    # Evaluate and checkpoint on the same schedule (once per epoch), then
    # reload the best checkpoint according to `metric_name` when training ends.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    # Optimisation hyper-parameters.
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    # Keep the raw "image" column: the on-the-fly transforms read it to build
    # "pixel_values".
    remove_unused_columns=False,
)

In [None]:
from sklearn.metrics import f1_score
import numpy as np

def compute_metrics(eval_pred):
    """Compute the micro-averaged F1 score for one evaluation pass.

    Args:
        eval_pred: pair ``(predictions, labels)``; ``predictions`` holds one row
            of per-class scores per example, ``labels`` the true class ids.

    Returns:
        dict with a single key ``"f1"``.
    """
    scores, labels = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    # scikit-learn expects the ground truth first: f1_score(y_true, y_pred).
    return dict(f1=f1_score(labels, predicted_ids, average='micro'))

In [None]:
import torch

# Wire the model, hyper-parameters, data splits, batching and metric together.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

In [None]:
# Load the TensorBoard notebook extension and open it on logs/, the directory
# passed as logging_dir to TrainingArguments.
%load_ext tensorboard
%tensorboard --logdir logs/

In [None]:
# Start fine-tuning with the Trainer configured above; the call's return value
# is displayed as the cell output.
trainer.train()