## Feed-forward сети

In [None]:
# !pip install torchvision

In [1]:
# Итак, давайте потренируемся тренировать нейронные сети прямого распространения (так как делали на паре)
# При этом попробуем создать свою функцию активации на одном из слоев
# Сделаем необходимые импорты

In [2]:
import torch
import numpy as np

from torch import nn
from torch import optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import torch.nn.functional as F

In [3]:
# Загрузим датасет CIFAR-100, сразу же создадим dataloader для него
# Если вам не хватает вычислительных ресурсов, то можно вернуться к CIFAR-10

In [4]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
BATCH_SIZE = 4

train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = BATCH_SIZE, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Создайте собственную архитектуру! Можете использовать все, что угодно, но в ограничении на использование линейные слои (пока без сверток)
# Давайте добавим ограниченный Leaky_relu, то есть output = max(0.1x, 0.5x)
# Ваша задача добавить его в архитектуру сети как функцию активации

In [6]:
class LeakyRelu(nn.Module):
    def __init__(self, border:float, alpha:float) -> None:
        super().__init__()
        self.border = border
        self.alpha = alpha
        
    def forward(self, input):
        x = F.leaky_relu(input)
        return torch.where(x > self.border, x * self.alpha, x)

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        img_vector_size = 3072
        self.fc1 = nn.Linear(img_vector_size,128)
        self.fc2 = nn.Linear(128,64)
        self.fc3 = nn.Linear(64,100)
        self.cust_relu = LeakyRelu(0.1, 0.5)
        
    def forward(self, x):
        x = x.view(x.shape[0], -1)        
        x = self.cust_relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def predict(self, x):
        x = x.view(x.shape[0], -1)
        
        x = self.cust_relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x
        
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [7]:
# Запустить обучение (по аналогии с тем, что делали на паре)

In [8]:
for epoch in tqdm(range(10)):  
    running_loss = 0.0 
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0], data[1]

        # обнуляем градиент
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # выводим статистику о процессе обучения
        running_loss += loss.item()
        if i % 300 == 0:    # печатаем каждые 300 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Training is finished!')

  0%|                                                    | 0/10 [00:00<?, ?it/s]

[1,     1] loss: 0.002
[1,   301] loss: 0.681
[1,   601] loss: 0.658
[1,   901] loss: 0.632
[1,  1201] loss: 0.629
[1,  1501] loss: 0.617
[1,  1801] loss: 0.606
[1,  2101] loss: 0.604
[1,  2401] loss: 0.605
[1,  2701] loss: 0.612
[1,  3001] loss: 0.595
[1,  3301] loss: 0.591
[1,  3601] loss: 0.576
[1,  3901] loss: 0.596
[1,  4201] loss: 0.589
[1,  4501] loss: 0.581
[1,  4801] loss: 0.577
[1,  5101] loss: 0.574
[1,  5401] loss: 0.582
[1,  5701] loss: 0.590
[1,  6001] loss: 0.572
[1,  6301] loss: 0.570
[1,  6601] loss: 0.576
[1,  6901] loss: 0.574
[1,  7201] loss: 0.571
[1,  7501] loss: 0.571
[1,  7801] loss: 0.568
[1,  8101] loss: 0.573
[1,  8401] loss: 0.578
[1,  8701] loss: 0.573
[1,  9001] loss: 0.568
[1,  9301] loss: 0.565
[1,  9601] loss: 0.566
[1,  9901] loss: 0.563
[1, 10201] loss: 0.549
[1, 10501] loss: 0.567
[1, 10801] loss: 0.573
[1, 11101] loss: 0.550
[1, 11401] loss: 0.565
[1, 11701] loss: 0.551
[1, 12001] loss: 0.566
[1, 12301] loss: 0.567


 10%|████▍                                       | 1/10 [01:04<09:44, 64.97s/it]

[2,     1] loss: 0.002
[2,   301] loss: 0.542
[2,   601] loss: 0.547
[2,   901] loss: 0.556
[2,  1201] loss: 0.550
[2,  1501] loss: 0.556
[2,  1801] loss: 0.540
[2,  2101] loss: 0.544
[2,  2401] loss: 0.540
[2,  2701] loss: 0.545
[2,  3001] loss: 0.556
[2,  3301] loss: 0.554
[2,  3601] loss: 0.549
[2,  3901] loss: 0.532
[2,  4201] loss: 0.558
[2,  4501] loss: 0.560
[2,  4801] loss: 0.552
[2,  5101] loss: 0.542
[2,  5401] loss: 0.564
[2,  5701] loss: 0.541
[2,  6001] loss: 0.548
[2,  6301] loss: 0.559
[2,  6601] loss: 0.553
[2,  6901] loss: 0.549
[2,  7201] loss: 0.554
[2,  7501] loss: 0.555
[2,  7801] loss: 0.549
[2,  8101] loss: 0.542
[2,  8401] loss: 0.543
[2,  8701] loss: 0.550
[2,  9001] loss: 0.553
[2,  9301] loss: 0.543
[2,  9601] loss: 0.546
[2,  9901] loss: 0.544
[2, 10201] loss: 0.545
[2, 10501] loss: 0.550
[2, 10801] loss: 0.555
[2, 11101] loss: 0.550
[2, 11401] loss: 0.552
[2, 11701] loss: 0.545
[2, 12001] loss: 0.549
[2, 12301] loss: 0.545


 20%|████████▊                                   | 2/10 [02:09<08:35, 64.45s/it]

[3,     1] loss: 0.002
[3,   301] loss: 0.529
[3,   601] loss: 0.519
[3,   901] loss: 0.544
[3,  1201] loss: 0.521
[3,  1501] loss: 0.547
[3,  1801] loss: 0.533
[3,  2101] loss: 0.538
[3,  2401] loss: 0.529
[3,  2701] loss: 0.532
[3,  3001] loss: 0.550
[3,  3301] loss: 0.535
[3,  3601] loss: 0.537
[3,  3901] loss: 0.530
[3,  4201] loss: 0.537
[3,  4501] loss: 0.540
[3,  4801] loss: 0.537
[3,  5101] loss: 0.534
[3,  5401] loss: 0.531
[3,  5701] loss: 0.527
[3,  6001] loss: 0.536
[3,  6301] loss: 0.535
[3,  6601] loss: 0.551
[3,  6901] loss: 0.546
[3,  7201] loss: 0.534
[3,  7501] loss: 0.535
[3,  7801] loss: 0.545
[3,  8101] loss: 0.547
[3,  8401] loss: 0.542
[3,  8701] loss: 0.543
[3,  9001] loss: 0.540
[3,  9301] loss: 0.543
[3,  9601] loss: 0.542
[3,  9901] loss: 0.543
[3, 10201] loss: 0.541
[3, 10501] loss: 0.549
[3, 10801] loss: 0.534
[3, 11101] loss: 0.545
[3, 11401] loss: 0.541
[3, 11701] loss: 0.531
[3, 12001] loss: 0.529
[3, 12301] loss: 0.536


 30%|█████████████▏                              | 3/10 [03:13<07:29, 64.25s/it]

[4,     1] loss: 0.001
[4,   301] loss: 0.520
[4,   601] loss: 0.521
[4,   901] loss: 0.517
[4,  1201] loss: 0.522
[4,  1501] loss: 0.538
[4,  1801] loss: 0.539
[4,  2101] loss: 0.529
[4,  2401] loss: 0.541
[4,  2701] loss: 0.532
[4,  3001] loss: 0.521
[4,  3301] loss: 0.527
[4,  3601] loss: 0.533
[4,  3901] loss: 0.547
[4,  4201] loss: 0.531
[4,  4501] loss: 0.530
[4,  4801] loss: 0.533
[4,  5101] loss: 0.533
[4,  5401] loss: 0.535
[4,  5701] loss: 0.536
[4,  6001] loss: 0.529
[4,  6301] loss: 0.537
[4,  6601] loss: 0.525
[4,  6901] loss: 0.538
[4,  7201] loss: 0.534
[4,  7501] loss: 0.538
[4,  7801] loss: 0.533
[4,  8101] loss: 0.518
[4,  8401] loss: 0.545
[4,  8701] loss: 0.527
[4,  9001] loss: 0.536
[4,  9301] loss: 0.534
[4,  9601] loss: 0.536
[4,  9901] loss: 0.533
[4, 10201] loss: 0.541
[4, 10501] loss: 0.526
[4, 10801] loss: 0.549
[4, 11101] loss: 0.532
[4, 11401] loss: 0.544
[4, 11701] loss: 0.529
[4, 12001] loss: 0.540
[4, 12301] loss: 0.536


 40%|█████████████████▌                          | 4/10 [04:16<06:23, 63.92s/it]

[5,     1] loss: 0.002
[5,   301] loss: 0.523
[5,   601] loss: 0.510
[5,   901] loss: 0.535
[5,  1201] loss: 0.520
[5,  1501] loss: 0.536
[5,  1801] loss: 0.533
[5,  2101] loss: 0.527
[5,  2401] loss: 0.534
[5,  2701] loss: 0.532
[5,  3001] loss: 0.527
[5,  3301] loss: 0.526
[5,  3601] loss: 0.524
[5,  3901] loss: 0.531
[5,  4201] loss: 0.515
[5,  4501] loss: 0.514
[5,  4801] loss: 0.529
[5,  5101] loss: 0.537
[5,  5401] loss: 0.533
[5,  5701] loss: 0.537
[5,  6001] loss: 0.525
[5,  6301] loss: 0.520
[5,  6601] loss: 0.549
[5,  6901] loss: 0.528
[5,  7201] loss: 0.510
[5,  7501] loss: 0.529
[5,  7801] loss: 0.534
[5,  8101] loss: 0.531
[5,  8401] loss: 0.519
[5,  8701] loss: 0.523
[5,  9001] loss: 0.534
[5,  9301] loss: 0.536
[5,  9601] loss: 0.530
[5,  9901] loss: 0.531
[5, 10201] loss: 0.531
[5, 10501] loss: 0.550
[5, 10801] loss: 0.536
[5, 11101] loss: 0.526
[5, 11401] loss: 0.517
[5, 11701] loss: 0.530
[5, 12001] loss: 0.536
[5, 12301] loss: 0.535


 50%|██████████████████████                      | 5/10 [05:17<05:15, 63.01s/it]

[6,     1] loss: 0.002
[6,   301] loss: 0.522
[6,   601] loss: 0.515
[6,   901] loss: 0.516
[6,  1201] loss: 0.522
[6,  1501] loss: 0.519
[6,  1801] loss: 0.513
[6,  2101] loss: 0.524
[6,  2401] loss: 0.516
[6,  2701] loss: 0.506
[6,  3001] loss: 0.520
[6,  3301] loss: 0.529
[6,  3601] loss: 0.521
[6,  3901] loss: 0.524
[6,  4201] loss: 0.530
[6,  4501] loss: 0.520
[6,  4801] loss: 0.523
[6,  5101] loss: 0.520
[6,  5401] loss: 0.537
[6,  5701] loss: 0.528
[6,  6001] loss: 0.522
[6,  6301] loss: 0.535
[6,  6601] loss: 0.522
[6,  6901] loss: 0.516
[6,  7201] loss: 0.538
[6,  7501] loss: 0.533
[6,  7801] loss: 0.527
[6,  8101] loss: 0.521
[6,  8401] loss: 0.532
[6,  8701] loss: 0.538
[6,  9001] loss: 0.532
[6,  9301] loss: 0.531
[6,  9601] loss: 0.531
[6,  9901] loss: 0.528
[6, 10201] loss: 0.541
[6, 10501] loss: 0.525
[6, 10801] loss: 0.525
[6, 11101] loss: 0.512
[6, 11401] loss: 0.529
[6, 11701] loss: 0.536
[6, 12001] loss: 0.536
[6, 12301] loss: 0.534


 60%|██████████████████████████▍                 | 6/10 [06:19<04:10, 62.69s/it]

[7,     1] loss: 0.002
[7,   301] loss: 0.515
[7,   601] loss: 0.509
[7,   901] loss: 0.506
[7,  1201] loss: 0.524
[7,  1501] loss: 0.519
[7,  1801] loss: 0.498
[7,  2101] loss: 0.522
[7,  2401] loss: 0.514
[7,  2701] loss: 0.522
[7,  3001] loss: 0.523
[7,  3301] loss: 0.535
[7,  3601] loss: 0.522
[7,  3901] loss: 0.532
[7,  4201] loss: 0.520
[7,  4501] loss: 0.530
[7,  4801] loss: 0.528
[7,  5101] loss: 0.509
[7,  5401] loss: 0.527
[7,  5701] loss: 0.527
[7,  6001] loss: 0.533
[7,  6301] loss: 0.512
[7,  6601] loss: 0.520
[7,  6901] loss: 0.530
[7,  7201] loss: 0.496
[7,  7501] loss: 0.514
[7,  7801] loss: 0.539
[7,  8101] loss: 0.526
[7,  8401] loss: 0.528
[7,  8701] loss: 0.526
[7,  9001] loss: 0.538
[7,  9301] loss: 0.526
[7,  9601] loss: 0.518
[7,  9901] loss: 0.513
[7, 10201] loss: 0.537
[7, 10501] loss: 0.524
[7, 10801] loss: 0.534
[7, 11101] loss: 0.525
[7, 11401] loss: 0.525
[7, 11701] loss: 0.525
[7, 12001] loss: 0.532
[7, 12301] loss: 0.529


 70%|██████████████████████████████▊             | 7/10 [07:23<03:08, 62.86s/it]

[8,     1] loss: 0.001
[8,   301] loss: 0.526
[8,   601] loss: 0.505
[8,   901] loss: 0.518
[8,  1201] loss: 0.519
[8,  1501] loss: 0.507
[8,  1801] loss: 0.506
[8,  2101] loss: 0.502
[8,  2401] loss: 0.517
[8,  2701] loss: 0.515
[8,  3001] loss: 0.519
[8,  3301] loss: 0.516
[8,  3601] loss: 0.515
[8,  3901] loss: 0.509
[8,  4201] loss: 0.524
[8,  4501] loss: 0.517
[8,  4801] loss: 0.526
[8,  5101] loss: 0.504
[8,  5401] loss: 0.511
[8,  5701] loss: 0.528
[8,  6001] loss: 0.534
[8,  6301] loss: 0.521
[8,  6601] loss: 0.516
[8,  6901] loss: 0.529
[8,  7201] loss: 0.512
[8,  7501] loss: 0.520
[8,  7801] loss: 0.514
[8,  8101] loss: 0.511
[8,  8401] loss: 0.530
[8,  8701] loss: 0.538
[8,  9001] loss: 0.522
[8,  9301] loss: 0.524
[8,  9601] loss: 0.523
[8,  9901] loss: 0.520
[8, 10201] loss: 0.536
[8, 10501] loss: 0.517
[8, 10801] loss: 0.527
[8, 11101] loss: 0.530
[8, 11401] loss: 0.525
[8, 11701] loss: 0.520
[8, 12001] loss: 0.535
[8, 12301] loss: 0.523


 80%|███████████████████████████████████▏        | 8/10 [08:25<02:05, 62.61s/it]

[9,     1] loss: 0.001
[9,   301] loss: 0.509
[9,   601] loss: 0.520
[9,   901] loss: 0.501
[9,  1201] loss: 0.509
[9,  1501] loss: 0.511
[9,  1801] loss: 0.514
[9,  2101] loss: 0.530
[9,  2401] loss: 0.515
[9,  2701] loss: 0.526
[9,  3001] loss: 0.520
[9,  3301] loss: 0.508
[9,  3601] loss: 0.517
[9,  3901] loss: 0.519
[9,  4201] loss: 0.511
[9,  4501] loss: 0.528
[9,  4801] loss: 0.521
[9,  5101] loss: 0.513
[9,  5401] loss: 0.525
[9,  5701] loss: 0.524
[9,  6001] loss: 0.511
[9,  6301] loss: 0.521
[9,  6601] loss: 0.502
[9,  6901] loss: 0.519
[9,  7201] loss: 0.527
[9,  7501] loss: 0.516
[9,  7801] loss: 0.529
[9,  8101] loss: 0.504
[9,  8401] loss: 0.532
[9,  8701] loss: 0.531
[9,  9001] loss: 0.529
[9,  9301] loss: 0.521
[9,  9601] loss: 0.509
[9,  9901] loss: 0.518
[9, 10201] loss: 0.523
[9, 10501] loss: 0.522
[9, 10801] loss: 0.525
[9, 11101] loss: 0.522
[9, 11401] loss: 0.519
[9, 11701] loss: 0.519
[9, 12001] loss: 0.508
[9, 12301] loss: 0.527


 90%|███████████████████████████████████████▌    | 9/10 [09:27<01:02, 62.41s/it]

[10,     1] loss: 0.002
[10,   301] loss: 0.494
[10,   601] loss: 0.516
[10,   901] loss: 0.506
[10,  1201] loss: 0.511
[10,  1501] loss: 0.504
[10,  1801] loss: 0.515
[10,  2101] loss: 0.517
[10,  2401] loss: 0.507
[10,  2701] loss: 0.520
[10,  3001] loss: 0.499
[10,  3301] loss: 0.516
[10,  3601] loss: 0.532
[10,  3901] loss: 0.517
[10,  4201] loss: 0.516
[10,  4501] loss: 0.516
[10,  4801] loss: 0.527
[10,  5101] loss: 0.505
[10,  5401] loss: 0.507
[10,  5701] loss: 0.515
[10,  6001] loss: 0.530
[10,  6301] loss: 0.534
[10,  6601] loss: 0.520
[10,  6901] loss: 0.514
[10,  7201] loss: 0.526
[10,  7501] loss: 0.525
[10,  7801] loss: 0.508
[10,  8101] loss: 0.505
[10,  8401] loss: 0.515
[10,  8701] loss: 0.517
[10,  9001] loss: 0.515
[10,  9301] loss: 0.519
[10,  9601] loss: 0.528
[10,  9901] loss: 0.515
[10, 10201] loss: 0.524
[10, 10501] loss: 0.523
[10, 10801] loss: 0.524
[10, 11101] loss: 0.518
[10, 11401] loss: 0.522
[10, 11701] loss: 0.522
[10, 12001] loss: 0.515
[10, 12301] loss

100%|███████████████████████████████████████████| 10/10 [10:28<00:00, 62.88s/it]

Training is finished!



