In [1]:
import os
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

import utils

# The dataset contains very large images that trip Pillow's decompression-bomb
# size check; silence the warning so it does not flood the training output.
# NOTE(review): images beyond Pillow's hard limit still raise an error.
warnings.filterwarnings('ignore', category=Image.DecompressionBombWarning)

In [2]:
# Optimiser / schedule hyperparameters.
lr = 5e-4
n_epochs = 100

# ImageFolder roots (one sub-folder per class). These are absolute, machine-
# specific paths — adjust them when running elsewhere.
training_path = "/DAS_Storage4/hyungseok/Training"
validation_path = "/DAS_Storage4/hyungseok/Validation"

# Seed numpy and torch so a Restart & Run All reproduces the same shuffle order
# and torch-based random initialisation (presumably including utils.init_weights).
# Bitwise-identical GPU results additionally require cuDNN determinism settings.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)


In [3]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) instead of failing on `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
image_size = (128, 128)
batch_size = 256
# Worker processes decode and resize images in parallel; with 0 workers all
# decoding runs in the notebook process and the GPU waits on it. Set to 0 if
# worker processes cause problems (e.g. limited shared memory).
num_workers = 4


def build_transform(size=image_size):
    """Return the shared preprocessing pipeline.

    Resizes to exactly ``size`` (aspect ratio is not preserved), converts to a
    tensor in [0, 1], then maps each of the three channels to [-1, 1].
    No augmentation is applied.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


train_trans = build_transform()
valid_trans = build_transform()

trainset = torchvision.datasets.ImageFolder(root=training_path, transform=train_trans)
validset = torchvision.datasets.ImageFolder(root=validation_path, transform=valid_trans)

# ImageFolder derives label indices from each root's own sub-folder names, so a
# class folder missing from one split would silently shift every later label.
assert trainset.class_to_idx == validset.class_to_idx, (
    "training and validation class folders differ; label indices would not match"
)

# Page-locked host memory speeds up host-to-GPU copies; only useful with CUDA.
pin = torch.cuda.is_available()
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin)
valid_loader = DataLoader(validset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=pin)

In [5]:
# ResNet-50 trained from scratch. Calling the builder with no arguments gives
# randomly initialised weights on every torchvision version and avoids the
# deprecation warning that `pretrained=` triggers on newer releases.
model = models.resnet50()
num_feature = model.fc.in_features
# Size the classifier head from the dataset rather than a hard-coded 128, so
# the logits always match the labels CrossEntropyLoss will receive.
num_classes = len(trainset.classes)
model.fc = nn.Linear(num_feature, num_classes)

In [6]:
# Place the network on the target device, then re-initialise its weights with
# the project helper (see utils.init_weights for the meaning of 'uniform').
model = model.to(device)
utils.init_weights(model, init_type='uniform')

# Multi-class classification loss on raw logits.
criterion = nn.CrossEntropyLoss()

# Adam with beta2 = 0.98, a very small eps, and L2 weight decay.
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=1e-5,
)

In [None]:
def run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in train mode and updated after
    every batch. Otherwise it is put in eval mode and run without gradient
    tracking.

    The loss is averaged per sample: each batch's loss is weighted by its
    size. This assumes ``criterion`` averages over the batch, which is the
    default reduction. Raises ValueError if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        # Iterate the loader directly so tqdm can show the total batch count.
        for x, y in tqdm(loader, desc=desc):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            if training:
                optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            if training:
                loss.backward()
                optimizer.step()

            batch = y.size(0)
            total_loss += loss.item() * batch
            correct += (output.argmax(dim=1) == y).sum().item()
            seen += batch

    if seen == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / seen, correct / seen


# torch.save does not create missing directories.
os.makedirs('model', exist_ok=True)

for epoch in range(n_epochs):
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device,
                                      optimizer=optimizer, desc=f"train {epoch + 1}")
    valid_loss, valid_acc = run_epoch(model, valid_loader, criterion, device,
                                      desc=f"valid {epoch + 1}")

    # 'loss' is the epoch's mean training loss as a plain float. It is not a
    # GPU tensor still attached to the autograd graph.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
    }, 'model/{}-new_model.pt'.format(epoch + 1))

    print('Epoch {}/{}, Train_Acc: {:.3f}, Train_Loss : {:.6f}, valid_Acc : {:.3f}, Valid_Loss : {:.6f}'.format(
        epoch + 1, n_epochs, train_acc, train_loss, valid_acc, valid_loss))
        

11it [02:34, 13.88s/it]