In [None]:
import torch
from torch import nn
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import SGD

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),])

dataset = CIFAR10("./", train=True, transform=transform, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:11<00:00, 14.9MB/s]


Extracting ./cifar-10-python.tar.gz to ./


In [None]:
loader = DataLoader(dataset, batch_size=128)

In [11]:
class MyNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer_1 = nn.Linear(3 * 32 * 32, 512)
    self.layer_2 = nn.Linear(512, 256)
    self.layer_3 = nn.Linear(256, 128)
    self.layer_4 = nn.Linear(128, 10)

    self.relu = nn.ReLU()

  def forward(self, x):
    x_flat = x.view(x.shape[0], -1)

    y1 = self.layer_1(x_flat)
    y1 = self.relu(y1)
    y2 = self.layer_2(y1)
    y2 = self.relu(y2)
    y3 = self.layer_3(y2)
    y3 = self.relu(y3)
    y4 = self.layer_4(y3)

    return y4

network = MyNetwork()

In [10]:
opt = SGD(network.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()


for epoch in range(100):
  sum_loss = 0.
  for i, batch in enumerate(loader):
    inputs = batch[0]
    targets = batch[1]

    y = network(inputs)

    loss_value = loss_func(y, targets)

    loss_value.backward()

    opt.step()
    sum_loss += loss_value.item()
    print(f'Loss: {sum_loss / (i + 1)}')

Loss: 2.2990593910217285
Loss: 2.2970950603485107
Loss: 2.2992268403371177
Loss: 2.3019534945487976
Loss: 2.302722692489624
Loss: 2.3035189310709634
Loss: 2.3034962245396207
Loss: 2.303164631128311
Loss: 2.3033637205759683
Loss: 2.3040583610534666
Loss: 2.304275382648815
Loss: 2.3045963247617087
Loss: 2.303811843578632
Loss: 2.3042194162096297
Loss: 2.304727474848429
Loss: 2.3047951608896255
Loss: 2.3047147077672623
Loss: 2.304528554280599
Loss: 2.3044634743740686
Loss: 2.304676079750061
Loss: 2.3045151120140437
Loss: 2.3044512163509023
Loss: 2.3043889584748642
Loss: 2.3043006161848703
Loss: 2.3043224811553955
Loss: 2.304338519389813
Loss: 2.304399578659623
Loss: 2.304469440664564
Loss: 2.3045428210291368
Loss: 2.3043367226918536
Loss: 2.304303592251193
Loss: 2.3041993379592896
Loss: 2.3041096528371177
Loss: 2.3041265852311077
Loss: 2.304008974347796
Loss: 2.303854054874844
Loss: 2.303868126224827
Loss: 2.303811556414554
Loss: 2.3036622206370034
Loss: 2.3035715579986573
Loss: 2.3034680

KeyboardInterrupt: 

# Write Network Use Conv

In [20]:
class MyNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 32, (3, 3), (2, 2))
    self.conv2 = nn.Conv2d(32, 64, (3, 3), (2, 2))
    self.conv3 = nn.Conv2d(64, 256, (3, 3), (2, 2))

    self.linear1 = nn.Linear(2304, 128)
    self.linear2 = nn.Linear(128, 10)

    self.relu = nn.ReLU()

  def forward(self, x):
    y1 = self.conv1(x)
    y1 = self.relu(y1)
    y2 = self.conv2(y1)
    y2 = self.relu(y2)
    y3 = self.conv3(y2)
    y3 = self.relu(y3)


    # 128 x 256 x 3 x 3
    y_flat = y3.view(x.shape[0], -1)
    # 128 x 2304

    y = self.linear1(y_flat)
    y = self.relu(y)
    y = self.linear2(y)

    return y

network = MyNetwork()

x = torch.rand(10, 3, 32, 32)
y = network(x)
print(y.shape)

torch.Size([10, 10])


In [21]:
opt = SGD(network.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()


for epoch in range(100):
  sum_loss = 0.
  for i, batch in enumerate(loader):
    inputs = batch[0]
    targets = batch[1]

    y = network(inputs)

    loss_value = loss_func(y, targets)

    loss_value.backward()

    opt.step()
    sum_loss += loss_value.item()
    print(f'Loss: {sum_loss / (i + 1)}')

Loss: 2.3070592880249023
Loss: 2.3056787252426147
Loss: 2.3056536515553794
Loss: 2.3052929043769836
Loss: 2.3053404331207275
Loss: 2.3057889143625894
Loss: 2.304546696799142
Loss: 2.305057555437088
Loss: 2.3041237195332847
Loss: 2.3039314985275268
Loss: 2.303777889771895
Loss: 2.3036649227142334
Loss: 2.304012995499831
Loss: 2.303530522755214
Loss: 2.3040530999501545
Loss: 2.3037372082471848
Loss: 2.303592317244586
Loss: 2.3038199610180325
Loss: 2.303697159415797
Loss: 2.3036030888557435
Loss: 2.3035329523540677
Loss: 2.303495342081243
Loss: 2.303395758504453
Loss: 2.3033570845921836
Loss: 2.303124437332153
Loss: 2.30303300344027
Loss: 2.302936395009359
Loss: 2.302774829523904
Loss: 2.3027428265275627
Loss: 2.3027154366175333
Loss: 2.3027543790878786
Loss: 2.3027893602848053
Loss: 2.3027735912438594
Loss: 2.302811664693496
Loss: 2.3028412137712753
Loss: 2.302780661318037
Loss: 2.3027322936702417
Loss: 2.3025553665663065
Loss: 2.3026306812579813
Loss: 2.302661108970642
Loss: 2.302761828

KeyboardInterrupt: 