In [1]:
from torchvision.datasets import MNIST, USPS, SVHN
import torch 
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import ToTensor,transforms,Compose,Resize,Lambda
import torch.nn.functional as fn
from model import CNN
import numpy as np

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)
device

'cuda'

In [3]:
# data download
trainM = MNIST(root='data',train=True,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
trainU = USPS(root='data',train=True,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
trainS = SVHN(root='data',split='train',download=True,transform=Compose([
Resize(28),
ToTensor(),
]))
testM = MNIST(root='data',train=False,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
testU = USPS(root='data',train=False,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
testS = SVHN(root='data',split='test',download=True,transform=Compose([
Resize(28),
ToTensor(),
]))

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat


In [4]:
def makeDataset(raw_data,batch_size):
    data_loader = DataLoader(dataset=raw_data,
                batch_size=batch_size,
                shuffle=True,
                drop_last=True) # batch 마지막 데이터 버림
    return data_loader

train_loader_M = makeDataset(trainM,16)
train_loader_U = makeDataset(trainU,16)
train_loader_S = makeDataset(trainS,16)

test_loader_M = makeDataset(testM,16)
test_loader_U = makeDataset(testU,16)
test_loader_S = makeDataset(testS,16)

In [5]:
train_loaders = (train_loader_M,train_loader_U,train_loader_S)
test_loaders = (test_loader_M,test_loader_U,test_loader_S)

names = ['MNIST','USPS','SVHN']

In [16]:
# train
for c in range(3):
    shape = (16,3,28,28)
    learning_rate = 0.001
    epoch = 20
    batches = [3,5,200] # num of data = batch*16
    
    model = CNN(shape)
    model = model.to(device)
    
    model.train()

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate)

    for e in range(epoch):
        avg_cost = 0
        loader = train_loaders[c] # 매 epoch마다 loader 생성
        for i,data in enumerate(loader):
            x,y = data
            x=x.to(device)
            y=y.to(device)
            optimizer.zero_grad()
            predict = model(x)
            cost = criterion(predict,y)
            cost.backward()
            optimizer.step()

            avg_cost += cost / batches[c]
            if i==batches[c]:
                break
        if (e+1)%5 == 0:
            print('[Epoch: {:>4}] cost = {:>.9}'.format(e + 1, avg_cost))
    torch.save(model.state_dict(),"./weight/base/"+names[c]+".pkl")
    

[Epoch:    5] cost = 2.1793766
[Epoch:   10] cost = 1.06350899
[Epoch:   15] cost = 0.842508912
[Epoch:   20] cost = 0.724280357
[Epoch:    5] cost = 1.13988841
[Epoch:   10] cost = 0.646030605
[Epoch:   15] cost = 0.520708382
[Epoch:   20] cost = 0.46530968
[Epoch:    5] cost = 1.08881831
[Epoch:   10] cost = 0.634732783
[Epoch:   15] cost = 0.553443193
[Epoch:   20] cost = 0.52062887


In [17]:
# 입력
# 세가지 모델에 대해 입력을 받음
# 세가지 데이터셋에 대해 cross check 진행

# evaluation


for m in range(3) : # 세가지 모델
    shape = (16,3,28,28)
    model = CNN(shape)
    model.load_state_dict(torch.load("./weight/base/"+names[m]+".pkl"))
    model = model.to(device)
    model.eval()
    print('model : ',names[m])
    for c in range(3): # 세가지 데이터셋
        loader = test_loaders[c]
        correct=0
        total=0

        #load model
        with torch.no_grad():
            for x,y in loader:
                x=x.to(device)
                y=y.to(device)
                predict = torch.max(model(x).data,1)[1]
                total+=len(y)
                correct+=(predict==y).sum().item()
        print(names[c]+' Test Accuracy: ',100.*correct/total, '%')

model :  MNIST
MNIST Test Accuracy:  83.73 %
USPS Test Accuracy:  55.15 %
SVHN Test Accuracy:  9.480639213275968 %
model :  USPS
MNIST Test Accuracy:  43.89 %
USPS Test Accuracy:  87.0 %
SVHN Test Accuracy:  12.2579901659496 %
model :  SVHN
MNIST Test Accuracy:  57.65 %
USPS Test Accuracy:  55.55 %
SVHN Test Accuracy:  83.27059004302397 %
