In [22]:
import wandb
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary

# torchvision
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader, random_split


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [23]:
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1) # conv5와 fc1 사이에 view 들어간다.
        self.fc1 = nn.Linear(in_features=256 * 5 * 5, out_features=4096) # fc layer 6 * 6 -> 5 * 5
        self.fc2 = nn.Linear(in_features=4096, out_features=4096)
        self.fc3 = nn.Linear(in_features=4096, out_features=4)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        
        x = x.view(x.size(0), -1) # 4차원을 1차원으로 펼쳐주는 층 (역할) -> flatten
        
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5)
        
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.5)
    
        x = F.log_softmax(self.fc3(x), dim=1)
        
        return x 

In [24]:
torchsummary.summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 54, 54]          34,944
            Conv2d-2          [-1, 256, 26, 26]         614,656
            Conv2d-3          [-1, 384, 12, 12]         885,120
            Conv2d-4          [-1, 384, 12, 12]       1,327,488
            Conv2d-5          [-1, 256, 12, 12]         884,992
            Linear-6                 [-1, 4096]      26,218,496
            Linear-7                 [-1, 4096]      16,781,312
            Linear-8                    [-1, 4]          16,388
Total params: 46,763,396
Trainable params: 46,763,396
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 4.64
Params size (MB): 178.39
Estimated Total Size (MB): 183.61
----------------------------------------------------------------


In [18]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdongju_coredottoday[0m ([33mcoredottoday-dongju[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# imagefloder 생성
train_imgfolder = ImageFolder(root='data/train',
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Resize(224),
                          transforms.RandomHorizontalFlip(p=0.8),
                          transforms.GaussianBlur(kernel_size=(19, 19), sigma=(1.0, 2.0)),
                          transforms.RandomRotation(degrees=(-30, 30), interpolation=transforms.InterpolationMode.BILINEAR, fill=0),
                          transforms.CenterCrop(224)
                      ]))


validation_imgfolder = ImageFolder(root='data/validation',
                          transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Resize(224),
                          transforms.CenterCrop(224)
                      ]))


test_imgfolder = ImageFolder(root='data/test',
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Resize(224),
                          transforms.CenterCrop(224)
                      ]))

In [13]:
# get data_loader
train_data_loader = DataLoader(dataset=train_imgfolder, 
                         batch_size=16, 
                         num_workers=0,
                         shuffle=True,
                         drop_last=True
                        )

valid_data_loader = DataLoader(dataset=validation_imgfolder, 
                         batch_size=16, 
                         num_workers=0
                        )

test_data_loader = DataLoader(dataset=test_imgfolder, 
                         batch_size=1, 
                         num_workers=0
                        )

In [14]:
def normal_init(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)

In [20]:
num_classes = 4
num_epochs = 100
batch_size = 16
learning_rate = 0.0001

model = AlexNet().to(device)
model.apply(normal_init)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = 0.0001)  


# Train the model
total_step = len(train_data_loader)

In [19]:
run = wandb.init(
    project="pbl",
    config={
        "learning_rate": learning_rate,
        "eprchs": num_epochs
    }
)

In [16]:
train_datanum = len(train_data_loader)*batch_size
valid_datanum = len(valid_data_loader)*batch_size

In [21]:
# 학습 진행
for epoch in range(num_epochs):
    loss_sum = 0
    correct = 0
    model.train()
    for train_x, train_y in train_data_loader:
        optimizer.zero_grad()
        output = model(train_x.cuda())
        
        loss = criterion(output.cpu(), train_y)
        
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        


        predicted = torch.max(output, 1)[1]
        correct += (train_y == predicted.cpu()).sum()

    print(f'--------------------------------------------------------')
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {loss_sum/len(train_data_loader):.4f} Accuracy: {correct / train_datanum:.4f}')
    train_loss = loss_sum/len(train_data_loader)
    train_acc = correct / train_datanum

    val_loss_sum = 0
    val_correct = 0
    model.eval()
    
    with torch.no_grad():

        for valid_x, valid_y in valid_data_loader:
            valid_x = valid_x.to('cuda')

            val_output = model(valid_x)
            val_loss = criterion(val_output.cpu(), valid_y)
            val_loss_sum += val_loss.item()
            
            val_predicted = torch.max(val_output, 1)[1]
            val_correct += (valid_y == val_predicted.cpu()).sum()
    
    print(f'Epoch [{epoch+1}/{num_epochs}] val_Loss: {val_loss_sum/len(valid_data_loader):.4f} val_Accuracy: {val_correct / valid_datanum:.4f}')

    validation_loss = val_loss_sum/len(valid_data_loader)
    validation_acc = val_correct / valid_datanum
    
    
    wandb.log({"train_acc":train_acc,"train_loss":train_loss,"validation_acc":validation_acc,"validation_loss":validation_loss})
        
    print('-----------------next-----------------')

--------------------------------------------------------
Epoch [1/100] Loss: 1.3291 Accuracy: 0.3764
Epoch [1/100] val_Loss: 1.3976 val_Accuracy: 0.2500
-----------------next-----------------
--------------------------------------------------------
Epoch [2/100] Loss: 1.1188 Accuracy: 0.4741
Epoch [2/100] val_Loss: 1.0274 val_Accuracy: 0.5625
-----------------next-----------------
--------------------------------------------------------
Epoch [3/100] Loss: 0.7855 Accuracy: 0.7069
Epoch [3/100] val_Loss: 0.7418 val_Accuracy: 0.7188
-----------------next-----------------
--------------------------------------------------------
Epoch [4/100] Loss: 0.6989 Accuracy: 0.7486
Epoch [4/100] val_Loss: 0.6743 val_Accuracy: 0.7500
-----------------next-----------------
--------------------------------------------------------
Epoch [5/100] Loss: 0.6113 Accuracy: 0.7730
Epoch [5/100] val_Loss: 0.6837 val_Accuracy: 0.7344
-----------------next-----------------
----------------------------------------

Epoch [43/100] val_Loss: 0.4712 val_Accuracy: 0.8594
-----------------next-----------------
--------------------------------------------------------
Epoch [44/100] Loss: 0.0709 Accuracy: 0.9763
Epoch [44/100] val_Loss: 0.5328 val_Accuracy: 0.8906
-----------------next-----------------
--------------------------------------------------------
Epoch [45/100] Loss: 0.0949 Accuracy: 0.9684
Epoch [45/100] val_Loss: 0.4274 val_Accuracy: 0.8594
-----------------next-----------------
--------------------------------------------------------
Epoch [46/100] Loss: 0.0886 Accuracy: 0.9713
Epoch [46/100] val_Loss: 0.4563 val_Accuracy: 0.8438
-----------------next-----------------
--------------------------------------------------------
Epoch [47/100] Loss: 0.0768 Accuracy: 0.9749
Epoch [47/100] val_Loss: 0.5086 val_Accuracy: 0.8281
-----------------next-----------------
--------------------------------------------------------
Epoch [48/100] Loss: 0.0778 Accuracy: 0.9763
Epoch [48/100] val_Loss: 0.712

Epoch [88/100] val_Loss: 0.3558 val_Accuracy: 0.9062
-----------------next-----------------
--------------------------------------------------------
Epoch [89/100] Loss: 0.0546 Accuracy: 0.9792
Epoch [89/100] val_Loss: 0.4733 val_Accuracy: 0.8594
-----------------next-----------------
--------------------------------------------------------
Epoch [90/100] Loss: 0.0345 Accuracy: 0.9864
Epoch [90/100] val_Loss: 0.4101 val_Accuracy: 0.8906
-----------------next-----------------
--------------------------------------------------------
Epoch [91/100] Loss: 0.0205 Accuracy: 0.9914
Epoch [91/100] val_Loss: 0.5733 val_Accuracy: 0.8438
-----------------next-----------------
--------------------------------------------------------
Epoch [92/100] Loss: 0.0290 Accuracy: 0.9892
Epoch [92/100] val_Loss: 0.9904 val_Accuracy: 0.8438
-----------------next-----------------
--------------------------------------------------------
Epoch [93/100] Loss: 0.0316 Accuracy: 0.9878
Epoch [93/100] val_Loss: 0.552