In [None]:
import torch
from torch import nn
from torchvision import models, datasets, transforms
import time
from tqdm.auto import tqdm

: 

In [2]:
# Обязательно к прочтению: тред на тему различных состояний нейронной сети в PyTorch
# https://stackoverflow.com/questions/51748138/pytorch-how-to-set-requires-grad-false
def set_requires_grad(model, value=False):
    for param in model.parameters():
        param.requires_grad = value

In [3]:
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/ubuntu/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 76.7MB/s]


In [4]:
model.fc

Linear(in_features=512, out_features=1000, bias=True)

In [5]:
transforms = weights.transforms()
transforms

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [6]:
weights.meta['categories'][:10]

['tench',
 'goldfish',
 'great white shark',
 'tiger shark',
 'hammerhead',
 'electric ray',
 'stingray',
 'cock',
 'hen',
 'ostrich']

In [7]:
#создаём свой кастомный слой классификации
num_classes = 10
num_in_features = model.fc.in_features

set_requires_grad(model, False)
model.fc = nn.Linear(num_in_features, num_classes)

In [8]:
next(model.fc.parameters()).requires_grad

True

In [9]:
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

loaders = {'train': trainloader, 'val': testloader}
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:22<00:00, 7618367.70it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [11]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
def train_model(model, dataloaders, criterion, optimizer, phases, num_epochs=3):
    start_time = time.time()

    acc_history = {k: list() for k in phases}
    loss_history = {k: list() for k in phases}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in phases:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            n_batches = len(dataloaders[phase])
            for inputs, labels in tqdm(dataloaders[phase], total=n_batches):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    preds = torch.argmax(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double()
            epoch_acc /= len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))
            loss_history[phase].append(epoch_loss)
            acc_history[phase].append(epoch_acc)

        print()

    time_elapsed = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,
                                                        time_elapsed % 60))

    return model, acc_history

# Дообучение модели - только последний слой

In [14]:
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

set_requires_grad(model, False)
set_requires_grad(model.fc, True)

model = model.to(device)
train_model(model, loaders, criterion, optimizer, phases=['train', 'val'], num_epochs=3)

Epoch 0/2
----------


  0%|          | 0/782 [00:00<?, ?it/s]

train Loss: 0.9473 Acc: 0.6954


  0%|          | 0/157 [00:00<?, ?it/s]

val Loss: 0.9239 Acc: 0.7243

Epoch 1/2
----------


  0%|          | 0/782 [00:00<?, ?it/s]

train Loss: 0.8949 Acc: 0.7266


  0%|          | 0/157 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7431a82b3600>
Traceback (most recent call last):
  File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7431a82b3600>^
^Traceback (most recent call last):
^  File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
^
      File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only test a child process'

  File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
      if w.is_alive(): 
               ^

val Loss: 0.9844 Acc: 0.7079

Epoch 2/2
----------


  0%|          | 0/782 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7431a82b3600><function _MultiProcessingDataLoaderIter.__del__ at 0x7431a82b3600>
Traceback (most recent call last):
  File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()

  File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
Traceback (most recent call last):
      File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
if w.is_alive():    
self._shutdown_workers() 
   File "/home/ubuntu/env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
     if w.is_alive():
          ^ ^^^^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in i

train Loss: 0.8755 Acc: 0.7337


  0%|          | 0/157 [00:00<?, ?it/s]

val Loss: 0.8861 Acc: 0.7338

Training complete in 3m 47s


(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU

# Дообучение модели - все слои

In [15]:
set_requires_grad(model, True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

model = model.to(device)
train_model(model, loaders, criterion, optimizer, phases=['train', 'val'], num_epochs=3)

Epoch 0/2
----------


  0%|          | 0/782 [00:00<?, ?it/s]

train Loss: 2.0922 Acc: 0.2989


  0%|          | 0/157 [00:00<?, ?it/s]

val Loss: 3.1476 Acc: 0.4000

Epoch 1/2
----------


  0%|          | 0/782 [00:00<?, ?it/s]

train Loss: 1.3294 Acc: 0.5236


  0%|          | 0/157 [00:00<?, ?it/s]

val Loss: 1.0948 Acc: 0.6109

Epoch 2/2
----------


  0%|          | 0/782 [00:00<?, ?it/s]

train Loss: 0.9229 Acc: 0.6765


  0%|          | 0/157 [00:00<?, ?it/s]

val Loss: 0.8523 Acc: 0.7057

Training complete in 7m 38s


(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU

In [16]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [32]:
set_requires_grad(model, False)
set_requires_grad(model.layer4, True)
set_requires_grad(model.fc, True)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

model = model.to(device)
train_model(model, loaders, criterion, optimizer, phases=['train', 'val'], num_epochs=3)

Epoch 0/2
----------


  0%|          | 0/782 [00:00<?, ?it/s]

KeyboardInterrupt: 