<a href="https://colab.research.google.com/github/irina-lebedeva/Comboloss/blob/main/PyTorch_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from __future__ import print_function, division
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms, datasets, models
import time
from torch.optim import lr_scheduler
import os

In [3]:
data_transform  = {
    "train": transforms.Compose([transforms.ToTensor(),
        transforms.RandomResizedCrop(244),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]),
        transforms.RandomHorizontalFlip(),
        ]),
    "val": transforms.Compose([transforms.ToTensor(),
        transforms.CenterCrop(244),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]),
        transforms.Resize(256),
        ]),
     }                                     

In [6]:
import zipfile
with zipfile.ZipFile("./hymenoptera_data.zip", 'r') as zip_ref:
    zip_ref.extractall("./")

In [8]:
import shutil
shutil.rmtree("./hymenoptera_data/raw/")

In [9]:
data_dir = "./hymenoptera_data"
image_dataset = {x: datasets.ImageFolder(os.path.join(data_dir,x), 
                                         data_transform[x]) 
                 for x in ['train', 'val']}

In [15]:
image_dataset

{'train': Dataset ImageFolder
     Number of datapoints: 244
     Root location: ./hymenoptera_data/train
     StandardTransform
 Transform: Compose(
                ToTensor()
                RandomResizedCrop(size=(244, 244), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                RandomHorizontalFlip(p=0.5)
            ), 'val': Dataset ImageFolder
     Number of datapoints: 152
     Root location: ./hymenoptera_data/val
     StandardTransform
 Transform: Compose(
                ToTensor()
                CenterCrop(size=(244, 244))
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
            )}

In [10]:
dataloaders = {x: torch.utils.data.DataLoader(image_dataset[x], 
                                              batch_size = 4, shuffle = True)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_dataset[x]) for x in ['train', 'val']}
class_names  = image_dataset['train'].classes
use_gpu = torch.cuda.is_available()                         

In [11]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = model.state_dict()
    best_acc = 0.0
    
    #Ваш код здесь
    losses = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # каждя эпоха имеет обучающую и тестовую стадии
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # установаить модель в режим обучения
            else:
                model.train(False)  # установить модель в режим предсказания

            running_loss = 0.0
            running_corrects = 0

            # итерируемся по батчам
            for data in dataloaders[phase]:
                # получаем картинки и метки
                inputs, labels = data

                # оборачиваем в переменные
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                else:
                    inputs, labels = inputs, labels

                # инициализируем градиенты параметров
                optimizer.zero_grad()

                # forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                # backward pass + оптимизируем только если это стадия обучения
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # статистика
                running_loss += loss.item()
                running_corrects += int(torch.sum(preds == labels.data))

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            
            # Ваш код здесь
            losses[phase].append(epoch_loss)
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # если достиглось лучшее качество, то запомним веса модели
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict()

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # загрузим лучшие веса модели
    model.load_state_dict(best_model_wts)
    return model, losses

In [69]:
def visualize_model(model, num_images=6):
    images_so_far = 0
    fig = plt.figure()

    for i, data in enumerate(dataloaders['val']):
        inputs, labels = data
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        else:
            inputs, labels = Variable(inputs), Variable(labels)

        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)

        for j in range(inputs.size()[0]):
            images_so_far += 1
            ax = plt.subplot(num_images // 2, 2, images_so_far)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])

            if images_so_far == num_images:
                return

In [70]:
def validate():
    
    losses = []
    
    for inputs, labels in dataloaders['val']:
        
        outputs  = model(inputs)
        loss = criterion(outputs, labels)
        _, preds = torch.max(outputs.data, 1)
        running_loss += loss.item()
        running_corrects += int(torch.sum(preds == labels.data))
        
    _loss = running_loss / dataset_sizes[phase]
    _acc = running_corrects / dataset_sizes[phase]    

In [12]:
model = models.alexnet(pretrained = True)
print(model)

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


  0%|          | 0.00/233M [00:00<?, ?B/s]

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [16]:
model.classifier = nn.Linear(in_features=9216, out_features = 2)

loss_fn = nn.CrossEntropyLoss()

optimizer_ft = optim.SGD(model.parameters(), lr = 0.001, momentum = 0.9)

exp_lr_scheduler = lr_scheduler.StepLR (optimizer_ft, step_size = 7, gamma = 0.1)

layers_to_unfreeze = 5

for param in model.features[: -layers_to_unfreeze]:
    param.requires_grad = False
   

In [14]:
model, losses = train_model(model, loss_fn, optimizer_ft, exp_lr_scheduler, 25)

Epoch 0/24
----------




RuntimeError: ignored

In [None]:
plt.plot(losses = ['train'], label ='train')
plt.plot(losses = ['val'], label ='val')