In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
! pip install tensorboardX
from tensorboardX import SummaryWriter

from tqdm import tqdm
from dataloader import load_dataset
from model import BottleNeck,ResNet


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data,target) in enumerate(train_loader):
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print(f"{batch_idx*len(data)}/{len(train_loader.dataset)}")

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            writer.add_scalar("Test Loss", test_loss, epoch)
            pred = output.argmax(1)
            correct += float((pred == target).sum())
            writer.add_scalar("Test Accuracy", correct, epoch)

        test_loss /= len(test_loader.dataset)
        correct /= len(test_loader.dataset)
        return test_loss, correct
        writer.close()

if __name__ == "__main__":

    num_epochs = 10
    learning_rate = 0.001

    trainloader, testloader = load_dataset()

    use_cuda = torch.cuda.is_available()
    print("use_cuda : ", use_cuda)
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model = ResNet(BottleNeck).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Set Summary writer for Tensorboard visualization
    writer = SummaryWriter("./resnet/tensorboard")

    for epoch in tqdm(range(1, num_epochs + 1)):
        train(model, device, trainloader, optimizer, epoch)
        test_loss, test_accuracy = test(model, device, testloader)
        writer.add_scalar("Test Loss", test_loss, epoch)
        writer.add_scalar("Test Accuracy", test_accuracy, epoch)
        print(f"Processing Result = Epoch : {epoch}   Loss : {test_loss}   Accuracy : {test_accuracy}")
        writer.close()



100%|██████████| 170M/170M [00:12<00:00, 13.3MB/s]


use_cuda :  True


  0%|          | 0/10 [00:00<?, ?it/s]

3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 10%|█         | 1/10 [08:40<1:18:07, 520.88s/it]

Processing Result = Epoch : 1   Loss : 0.012732503044605254   Accuracy : 0.4327
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 20%|██        | 2/10 [17:39<1:10:48, 531.05s/it]

Processing Result = Epoch : 2   Loss : 0.008196565794944764   Accuracy : 0.6265
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 30%|███       | 3/10 [26:37<1:02:20, 534.35s/it]

Processing Result = Epoch : 3   Loss : 0.007867935943603516   Accuracy : 0.64
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 40%|████      | 4/10 [35:35<53:35, 535.97s/it]  

Processing Result = Epoch : 4   Loss : 0.006857558864355087   Accuracy : 0.6971
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 50%|█████     | 5/10 [44:34<44:44, 536.89s/it]

Processing Result = Epoch : 5   Loss : 0.006106131935119629   Accuracy : 0.7371
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 60%|██████    | 6/10 [53:32<35:49, 537.40s/it]

Processing Result = Epoch : 6   Loss : 0.005059958317875862   Accuracy : 0.7758
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 70%|███████   | 7/10 [1:02:31<26:53, 537.71s/it]

Processing Result = Epoch : 7   Loss : 0.005708589142560959   Accuracy : 0.7668
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 80%|████████  | 8/10 [1:11:29<17:55, 537.91s/it]

Processing Result = Epoch : 8   Loss : 0.005004674953222275   Accuracy : 0.7823
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


 90%|█████████ | 9/10 [1:20:27<08:57, 537.89s/it]

Processing Result = Epoch : 9   Loss : 0.0051020045667886735   Accuracy : 0.7807
3712/50000
7552/50000
11392/50000
15232/50000
19072/50000
22912/50000
26752/50000
30592/50000
34432/50000
38272/50000
42112/50000
45952/50000
49792/50000


100%|██████████| 10/10 [1:29:24<00:00, 536.49s/it]

Processing Result = Epoch : 10   Loss : 0.004858658400177956   Accuracy : 0.8115



