In [None]:
import copy
import shutil
from pathlib import Path

import torch
from torch import nn
from torchvision import transforms, datasets
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import random_split, DataLoader
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from tqdm import tqdm
import numpy as np

In [None]:
# ResNet-18 backbone with its final layer replaced by a `num_classes`-way head.
# The network returns raw logits on purpose: nn.CrossEntropyLoss applies
# log-softmax internally, so a Softmax layer inside the model would normalize
# the outputs twice and flatten the gradients. Use `predict_proba` when actual
# class probabilities are needed.

class Model(nn.Module):
    """ResNet-18 classifier producing unnormalized class scores (logits).

    Args:
        num_classes: number of output classes (default 3).
        weights: torchvision weights for the backbone. Defaults to
            ``ResNet18_Weights.DEFAULT`` (pretrained); pass ``None`` for a
            randomly initialized backbone.
    """

    def __init__(self, num_classes=3, weights=ResNet18_Weights.DEFAULT):
        super().__init__()
        self.resnet = resnet18(weights=weights)
        in_features = self.resnet.fc.in_features
        # Kept inside nn.Sequential so the parameter names stay `resnet.fc.0.*`,
        # matching state_dicts produced by the earlier version of this class.
        self.resnet.fc = nn.Sequential(
            nn.Linear(in_features, num_classes))

    def forward(self, x):
        """Return logits of shape (batch, num_classes); feed these to CrossEntropyLoss."""
        return self.resnet(x)

    @torch.no_grad()
    def predict_proba(self, x):
        """Return class probabilities (softmax over the class dimension).

        The backbone contains BatchNorm layers, so call ``model.eval()`` first
        for deterministic inference.
        """
        return torch.softmax(self.forward(x), dim=1)

In [None]:
# Build the classifier and move it to the first CUDA GPU when one is present,
# falling back to the CPU otherwise. The bare `model` on the last line makes
# Jupyter render the architecture.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

model = Model()
model = model.to(device)
model

Model(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [None]:
# Run this if the data-loading cell below fails. Jupyter autosaves into a hidden
# `.ipynb_checkpoints` folder, and ImageFolder treats every sub-folder of ./train
# as a class, so that folder can show up as a bogus extra class.
# Note the leading dot: the previous `rm ./train/ipynb_checkpoints` targeted a
# path that never exists, so it could not remove the folder.
train_dir = Path('./train')
checkpoints_dir = train_dir / '.ipynb_checkpoints'
if checkpoints_dir.exists():
    shutil.rmtree(checkpoints_dir)

# Show what is left in ./train (hidden entries included) so stray folders are easy to spot.
sorted(p.name for p in train_dir.iterdir()) if train_dir.is_dir() else f'{train_dir} does not exist'

rm: cannot remove './train/ipynb_checkpoints': No such file or directory
ls: cannot access './train/': No such file or directory


In [None]:
# Load the images: ImageFolder maps each sub-folder of ./train to one class.
# Train and validation are split from the same folder, but only the training
# split gets random augmentation, so validation metrics are not perturbed by it.
DATA_ROOT = './train'
SEED = 42
BATCH_SIZE = 32
TRAIN_FRACTION = 0.8
IMAGE_SIZE = (224, 224)
# NOTE(review): RandomRotation takes *degrees*, so 0.3 rotates by at most ±0.3° —
# effectively no augmentation. If 0.3 rad (~17°) was intended, use 17 here.
ROTATION_DEGREES = 0.3

torch.manual_seed(SEED)  # reproducible augmentation and loader shuffling

# ImageNet channel mean/std, as expected by torchvision's pretrained ResNet-18 weights.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

transform = transforms.Compose([transforms.RandomRotation(ROTATION_DEGREES),
                                transforms.ToTensor(),
                                transforms.Resize(IMAGE_SIZE, antialias=True),
                                normalize,
                               ])
eval_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Resize(IMAGE_SIZE, antialias=True),
                                     normalize,
                                    ])

dataset = datasets.ImageFolder(root=DATA_ROOT, transform=transform)
eval_dataset = datasets.ImageFolder(root=DATA_ROOT, transform=eval_transform)
# Both views must list the same files in the same order so one index means one image.
assert dataset.samples == eval_dataset.samples
num_classes = len(dataset.classes)
num_classe = num_classes  # old (misspelled) name, kept so other cells keep working
print("Classes: ", dataset.classes)

train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
assert train_size > 0 and val_size > 0, f"too few images to split: {len(dataset)}"

# Seeded permutation -> the same train/val split on every run.
split_generator = torch.Generator().manual_seed(SEED)
permutation = torch.randperm(len(dataset), generator=split_generator).tolist()
train = torch.utils.data.Subset(dataset, permutation[:train_size])
val = torch.utils.data.Subset(eval_dataset, permutation[train_size:])

train_dataloader = torch.utils.data.DataLoader(dataset=train,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)
# Metrics don't depend on sample order, so validation is read sequentially.
val_dataloader = torch.utils.data.DataLoader(dataset=val,
                                             batch_size=BATCH_SIZE,
                                             shuffle=False)

Classes:  ['left', 'other', 'right']


In [None]:
# Train with cross-entropy + Adam, evaluating on the validation split after each
# epoch. The weights with the best macro-F1 are kept in `best_state_dict`.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 50
best_f1 = 0
best_state_dict = None

for epoch in range(num_epochs):
    print(f'Epoch: {epoch+1}, num epochs: {num_epochs}')

    # ---- training ----
    model.train()
    train_bar = tqdm(train_dataloader, total=len(train_dataloader))
    for images, labels in train_bar:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_bar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

    # ---- validation ----
    model.eval()
    logits = []
    y_trues = []
    eval_loss = 0.0
    with torch.no_grad():
        for images, labels in tqdm(val_dataloader, total=len(val_dataloader)):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # criterion returns the batch mean; weight by batch size so the
            # final value is a per-sample average over the whole split.
            eval_loss += criterion(outputs, labels).item() * labels.size(0)
            logits.append(outputs.cpu().numpy())
            y_trues.append(labels.cpu().numpy())
    logits = np.concatenate(logits, 0)
    y_trues = np.concatenate(y_trues, 0)
    eval_loss /= len(y_trues)

    # Multi-class prediction is the highest-scoring class. Thresholding column 1
    # at 0.5 (the old code) could only ever predict classes 0 and 1.
    y_preds = logits.argmax(axis=1)
    # Score every class, even one missing from this split; zero_division=0
    # makes the "no predicted samples" case explicit instead of warning.
    metric_kwargs = dict(labels=np.arange(logits.shape[1]), average="macro", zero_division=0)
    result = {
        "eval_accuracy": float(accuracy_score(y_trues, y_preds)),
        "eval_recall": float(recall_score(y_trues, y_preds, **metric_kwargs)),
        "eval_precision": float(precision_score(y_trues, y_preds, **metric_kwargs)),
        "eval_f1": float(f1_score(y_trues, y_preds, **metric_kwargs)),
        "eval_loss": float(eval_loss),
    }

    print("***** Eval results *****")
    for key in sorted(result):
        print(key, round(result[key], 4))
    if result['eval_f1'] > best_f1:
        best_f1 = result['eval_f1']
        # Snapshot the weights; model.state_dict() alone would keep changing.
        best_state_dict = copy.deepcopy(model.state_dict())
        print("  " + "*" * 20)
        print("  Best f1:", round(best_f1, 4))
        print("  " + "*" * 20)



  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 1, num epochs: 50


epoch 0 loss 0.5516905784606934: 100%|██████████| 1/1 [00:08<00:00,  8.40s/it]
100%|██████████| 1/1 [00:00<00:00,  1.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.4
eval_loss 0
eval_precision 0.3333
eval_recall 0.5
eval_threshold 0.5
  ********************
  Best f1: 0.4
  ********************




Epoch: 2, num epochs: 50


epoch 1 loss 0.6470877528190613: 100%|██████████| 1/1 [00:08<00:00,  8.24s/it]
100%|██████████| 1/1 [00:00<00:00,  1.75it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 3, num epochs: 50


epoch 2 loss 0.5550135970115662: 100%|██████████| 1/1 [00:06<00:00,  6.21s/it]
100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 4, num epochs: 50


epoch 3 loss 0.5556450486183167: 100%|██████████| 1/1 [00:08<00:00,  8.21s/it]
100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 5, num epochs: 50


epoch 4 loss 0.5553067922592163: 100%|██████████| 1/1 [00:06<00:00,  6.11s/it]
100%|██████████| 1/1 [00:00<00:00,  1.79it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 6, num epochs: 50


epoch 5 loss 0.6202312111854553: 100%|██████████| 1/1 [00:08<00:00,  8.17s/it]
100%|██████████| 1/1 [00:00<00:00,  1.75it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 7, num epochs: 50


epoch 6 loss 0.554892361164093: 100%|██████████| 1/1 [00:06<00:00,  6.18s/it]
100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 8, num epochs: 50


epoch 7 loss 0.5528573393821716: 100%|██████████| 1/1 [00:08<00:00,  8.00s/it]
100%|██████████| 1/1 [00:00<00:00,  1.81it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 9, num epochs: 50


epoch 8 loss 0.55581134557724: 100%|██████████| 1/1 [00:06<00:00,  6.11s/it]
100%|██████████| 1/1 [00:00<00:00,  1.74it/s]
  _warn_prf(average, modifier, msg_start, len(result))


***** Eval results *****
eval_f1 0.1667
eval_loss 0
eval_precision 0.1111
eval_recall 0.3333
eval_threshold 0.5




Epoch: 10, num epochs: 50


epoch 9 loss 0.5522676706314087:   0%|          | 0/1 [00:07<?, ?it/s]


KeyboardInterrupt: ignored