In [2]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter

from model import GoogLeNet
from dataloader import load_dataset
from tqdm import tqdm

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data,target) in enumerate(train_loader):
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        if model.aux_logits:
          loss0 = criterion(output[0], target)
          loss1 = criterion(output[1], target)
          loss2 = criterion(output[2], target)
          loss = loss0 + (0.3 * loss1) + (0.3 * loss2)
        else:
          loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 30 == 0:
            print(f"{batch_idx*len(data)}/{len(train_loader.dataset)}")

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target, reduction='mean').item()
            writer.add_scalar("Test Loss", test_loss, epoch)
            pred = output.argmax(1)
            correct += float((pred == target).sum())
            writer.add_scalar("Test Accuracy", correct, epoch)

        test_loss /= len(test_loader.dataset)
        correct /= len(test_loader.dataset)
        return test_loss, correct
        writer.close()

if __name__ == "__main__":

    num_epochs = 20
    learning_rate = 0.0001

    use_cuda = torch.cuda.is_available()
    print("use_cuda : ", use_cuda)
    device = torch.device("cuda:0" if use_cuda else "cpu")

    train_loader, test_loader = load_dataset()

    model = GoogLeNet().to(device)
    criterion = F.cross_entropy
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    writer = SummaryWriter("./googlenet/tensorboard")

    for epoch in tqdm(range(1, num_epochs + 1)):
        train(model, device, train_loader, optimizer, epoch)
        test_loss, test_accuracy = test(model, device, test_loader)
        writer.add_scalar("Test Loss", test_loss, epoch)
        writer.add_scalar("Test Accuracy", test_accuracy, epoch)
        print(f"Processing Result = Epoch : {epoch}   Loss : {test_loss}   Accuracy : {test_accuracy}")
        writer.close()


    print(f"Result of GoogleNet = Epoch : {epoch}   Loss : {test_loss}   Accuracy : {test_accuracy}")

use_cuda :  True


100%|██████████| 170M/170M [00:13<00:00, 12.6MB/s]
  0%|          | 0/20 [00:00<?, ?it/s]

0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


  5%|▌         | 1/20 [03:39<1:09:39, 220.00s/it]

Processing Result = Epoch : 1   Loss : 0.009254997539520264   Accuracy : 0.5881
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 10%|█         | 2/20 [07:19<1:05:53, 219.66s/it]

Processing Result = Epoch : 2   Loss : 0.007037764567136765   Accuracy : 0.7012
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 15%|█▌        | 3/20 [10:58<1:02:07, 219.24s/it]

Processing Result = Epoch : 3   Loss : 0.005902400892972946   Accuracy : 0.7483
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 20%|██        | 4/20 [14:36<58:23, 218.99s/it]  

Processing Result = Epoch : 4   Loss : 0.005392308953404427   Accuracy : 0.7742
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 25%|██▌       | 5/20 [18:16<54:49, 219.31s/it]

Processing Result = Epoch : 5   Loss : 0.0054383423715829845   Accuracy : 0.7647
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 30%|███       | 6/20 [21:56<51:14, 219.64s/it]

Processing Result = Epoch : 6   Loss : 0.004700991803407669   Accuracy : 0.804
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 35%|███▌      | 7/20 [25:36<47:34, 219.59s/it]

Processing Result = Epoch : 7   Loss : 0.005020406407117844   Accuracy : 0.7931
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 40%|████      | 8/20 [29:15<43:54, 219.53s/it]

Processing Result = Epoch : 8   Loss : 0.0047625494688749315   Accuracy : 0.8065
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 45%|████▌     | 9/20 [32:54<40:12, 219.32s/it]

Processing Result = Epoch : 9   Loss : 0.005324128893017769   Accuracy : 0.803
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 50%|█████     | 10/20 [36:33<36:31, 219.15s/it]

Processing Result = Epoch : 10   Loss : 0.005458261877298355   Accuracy : 0.8011
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 55%|█████▌    | 11/20 [40:13<32:54, 219.40s/it]

Processing Result = Epoch : 11   Loss : 0.004914299017190933   Accuracy : 0.8153
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 60%|██████    | 12/20 [43:53<29:16, 219.56s/it]

Processing Result = Epoch : 12   Loss : 0.005430537310242653   Accuracy : 0.8065
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 65%|██████▌   | 13/20 [47:33<25:39, 219.86s/it]

Processing Result = Epoch : 13   Loss : 0.004706041076779365   Accuracy : 0.8284
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 70%|███████   | 14/20 [51:12<21:56, 219.49s/it]

Processing Result = Epoch : 14   Loss : 0.004623232109099627   Accuracy : 0.829
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 75%|███████▌  | 15/20 [54:51<18:16, 219.36s/it]

Processing Result = Epoch : 15   Loss : 0.004829658418893814   Accuracy : 0.8359
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 80%|████████  | 16/20 [58:31<14:38, 219.50s/it]

Processing Result = Epoch : 16   Loss : 0.005109167924523353   Accuracy : 0.8344
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 85%|████████▌ | 17/20 [1:02:11<10:58, 219.66s/it]

Processing Result = Epoch : 17   Loss : 0.004544503811001778   Accuracy : 0.8485
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 90%|█████████ | 18/20 [1:05:50<07:19, 219.63s/it]

Processing Result = Epoch : 18   Loss : 0.005240884357690811   Accuracy : 0.8366
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


 95%|█████████▌| 19/20 [1:09:30<03:39, 219.65s/it]

Processing Result = Epoch : 19   Loss : 0.005490993082523346   Accuracy : 0.8399
0/50000
3840/50000
7680/50000
11520/50000
15360/50000
19200/50000
23040/50000
26880/50000
30720/50000
34560/50000
38400/50000
42240/50000
46080/50000
31200/50000


100%|██████████| 20/20 [1:13:09<00:00, 219.47s/it]

Processing Result = Epoch : 20   Loss : 0.0050972509205341335   Accuracy : 0.8391
Result of GoogleNet = Epoch : 20   Loss : 0.0050972509205341335   Accuracy : 0.8391



