In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import pandas as pd
import ipyplot
import torch
import pytorch_lightning as pl
from sklearn.metrics import f1_score
%matplotlib inline

import torchvision.transforms as transforms
from torch.nn import functional as F
from torch import nn
from torch.nn import *
import pytorch_lightning as pl
import timm

from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms
from torch.nn import functional as F
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
import timm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import ipyplot
from skimage import io
from sklearn.metrics import balanced_accuracy_score
from tqdm.auto import tqdm

## Model

In [None]:
class ImageClassifier(pl.LightningModule):
    """LightningModule wrapping an image backbone that outputs class logits.

    Args:
        trunk: module mapping an image batch to class logits. When ``None``,
            timm's pretrained ``mobilenetv2_100`` with a 2-class head is created.
        class_weight: optional per-class weights for the cross-entropy loss
            (tensor or sequence of floats, one entry per class).
        learning_rate: learning rate for the AdamW optimizer.
    """

    def __init__(self, trunk=None, class_weight=None, learning_rate=1e-3):
        super().__init__()
        # Store constructor args in checkpoints so `load_from_checkpoint` rebuilds
        # the model with the same loss weights / learning rate. `trunk` is a module,
        # not a hyperparameter; its weights are already part of the state dict.
        self.save_hyperparameters(ignore=["trunk"])

        weight = None if class_weight is None else torch.as_tensor(class_weight, dtype=torch.float32)
        # A buffer (unlike a plain attribute) follows the module through `.to(device)`,
        # so the loss weights always sit on the same device as the logits.
        # persistent=False keeps it out of the state dict, so checkpoints load
        # regardless of whether weights were configured.
        self.register_buffer("class_weight", weight, persistent=False)

        # `is None` instead of `or`: container modules define __len__, so an empty
        # nn.Sequential would be falsy and silently replaced.
        self.trunk = trunk if trunk is not None else timm.create_model(
            'mobilenetv2_100', pretrained=True, num_classes=2)
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return raw class logits for the batch ``x``."""
        return self.trunk(x)

    def predict_proba(self, x):
        """Return class probabilities (softmax of the logits over dim 1)."""
        return nn.functional.softmax(self.forward(x), dim=1)

    def predict(self, x):
        """Return the index of the highest-scoring class for each sample."""
        return self.forward(x).argmax(dim=1)

    def _compute_loss(self, batch):
        """(Optionally class-weighted) cross-entropy of an ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        return nn.functional.cross_entropy(logits, y, weight=self.class_weight)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        """AdamW over all parameters at ``self.learning_rate``.

        No LR scheduler: a StepLR with step_size=1 and the default gamma=0.1
        would shrink the learning rate 10x every epoch and stall training
        within a few epochs.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        

In [None]:
# Training hyperparameters — the one place to tune them.
MAX_EPOCHS: int = 100        # maximum number of training epochs
BATCH_SIZE: int = 128        # images per mini-batch
LEARNING_RATE: float = 1e-3  # optimizer learning rate

In [None]:
# Scales the colour jitter below: brightness/contrast/saturation get 0.8x and hue
# 0.2x of this value (torchvision requires hue <= 0.5). Set to 0 to disable jitter.
JITTER_STRENGTH = 1.0

# Training-time pipeline. The Random* steps make it non-deterministic, so it is
# not suitable for evaluation/inference.
# NOTE(review): there is no Normalize step; the pretrained timm backbone presumably
# expects ImageNet-normalised inputs — confirm before adding one (it would change
# the inputs any existing checkpoint was trained on).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.RandomHorizontalFlip(p=0.4),
     transforms.RandomApply([transforms.ColorJitter(
            0.8 * JITTER_STRENGTH,
            0.8 * JITTER_STRENGTH,
            0.8 * JITTER_STRENGTH,
            0.2 * JITTER_STRENGTH,
        )], p=0.5),
     transforms.RandomGrayscale(p=0.2),
     transforms.Resize((100, 100))
     ]
)

In [None]:
class ImagesDataset(Dataset):
    """Map-style dataset that reads images from disk and yields ``(image, label)``.

    Args:
        image_paths: indexable sequence of image file paths (read with ``io.imread``).
        labels: optional indexable sequence of labels aligned with ``image_paths``.
            When omitted, every sample gets the placeholder label ``-1``.
        transform: optional callable applied to the H x W x 3 numpy image.
    """

    def __init__(self, image_paths, labels = None, transform=None):
        super().__init__()
        self.image_paths = image_paths
        self.labels = labels
        # Alias of `labels`; presumably for code expecting a torchvision-style `.targets` — TODO confirm.
        self.targets = self.labels
        self.transform = transform

        if self.labels is not None:
            assert len(self.image_paths) == len(self.labels), (
                f"{len(self.image_paths)} image paths but {len(self.labels)} labels")

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _to_rgb(img):
        """Coerce a decoded image array to H x W x 3.

        Grayscale (H x W or H x W x 1) and grayscale+alpha (H x W x 2) images
        have their gray channel replicated three times; channels beyond the
        third (e.g. the alpha of RGBA) are dropped.
        """
        if img.ndim == 2:
            img = img[..., np.newaxis]
        if img.shape[-1] < 3:
            img = np.repeat(img[..., :1], 3, axis=-1)
        return img[..., :3]

    def __getitem__(self, idx):
        # -1 is a placeholder for unlabelled samples; it is not a valid class
        # index, so such samples must never reach the loss.
        label = self.labels[idx] if self.labels is not None else -1
        img = self._to_rgb(io.imread(self.image_paths[idx]))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

def evaluate_model(model, dataset, batch_size=32, num_workers=4, show_progress=True):
    """Run ``model.predict`` over ``dataset`` and collect true labels and predictions.

    Args:
        model: module exposing ``predict(x)`` (returning a tensor of class
            indices per sample) plus the usual ``eval``/``train``/``parameters``.
        dataset: dataset yielding ``(input, label)`` pairs; labels must collate
            into a tensor.
        batch_size: samples per inference batch.
        num_workers: DataLoader worker processes.
        show_progress: wrap the loop in a tqdm progress bar.

    Returns:
        ``(labels, predictions)`` — two lists in dataset order.
    """
    was_training = model.training
    model.eval()
    # Send inputs to wherever the weights live (e.g. GPU); parameter-less models use CPU.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         shuffle=False)
    batches = tqdm(loader) if show_progress else loader

    predictions = []
    labels = []
    try:
        with torch.no_grad():
            for x, y in batches:
                predictions.extend(model.predict(x.to(device)).cpu().tolist())
                labels.extend(y.tolist())
    finally:
        # Restore the caller's mode instead of leaving dropout/batch-norm frozen.
        model.train(was_training)

    return labels, predictions



In [None]:
# Inference must not see random augmentations: keep this in sync with the
# deterministic (non-Random*) steps of `transform`.
eval_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((100, 100))])

train_dataset = ImagesDataset(train_image_paths, train_labels, transform=transform)

# Unlabelled: every sample comes back with the placeholder label -1.
test_dataset = ImagesDataset(test_image_paths, transform=eval_transform)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           num_workers=16,
                                           # reshuffle every epoch so batches don't follow file order
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=1,
                                          num_workers=16,
                                          shuffle=False)


In [None]:
model = ImageClassifier(learning_rate=LEARNING_RATE)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    log_every_n_steps=1,  # log training metrics on every step
    accelerator="gpu",
    devices=-1,           # all visible GPUs (replaces the deprecated `gpus=-1`)
)
# NOTE(review): no validation loader is passed, so `validation_step` ("val_loss") never runs.
trainer.fit(model, train_loader)

In [None]:
# Score on un-augmented images: `transform` applies random flips, colour jitter and
# grayscale, which would make this metric noisy and non-reproducible.
eval_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((100, 100))])
train_eval_dataset = ImagesDataset(train_image_paths, train_labels, transform=eval_transform)
labels, predictions = evaluate_model(model, train_eval_dataset)

# `pos_label` (not `labels`) selects the positive class when average='binary'.
f1_score(labels, predictions, pos_label=1, average='binary')

## Дообучение

На самом деле достаточно загрузить сохранённую модель из чекпоинта (`model.ckpt`) и продолжить обучение на новом датасете, который мы разметили на Толоке.

In [None]:
# Lightning checkpoint to resume from — presumably saved from the training above; TODO confirm.
# load_from_checkpoint unpickles the file via torch.load, so only load trusted checkpoints.
CHECKPOINT_PATH = 'model.ckpt'
model = ImageClassifier.load_from_checkpoint(CHECKPOINT_PATH)

In [None]:
new_train_dataset = ImagesDataset(new_train_image_paths, new_train_labels, transform=transform)
new_train_loader = torch.utils.data.DataLoader(new_train_dataset,
                                               batch_size=BATCH_SIZE,
                                               num_workers=16,
                                               # reshuffle every epoch so batches don't follow file order
                                               shuffle=True)

# Build a fresh Trainer: this section starts from a loaded checkpoint, so the earlier
# `trainer` may not exist in this session, and a Trainer that already completed `fit`
# has reached max_epochs and would typically stop immediately.
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    log_every_n_steps=1,  # log training metrics on every step
    accelerator="gpu",
    devices=-1,           # all visible GPUs
)
trainer.fit(model, new_train_loader)