In [1]:
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import config
from temporal_ensembling import train_and_eval
from utils import Gaussian_Noise, save_exp_result
import datetime

class CNN_for_MNIST(nn.Module):
    """Small two-convolution CNN producing 10 class logits for MNIST-sized images.

    Pipeline: Gaussian input noise (training mode only) -> conv3x3 -> maxpool
    -> LeakyReLU -> conv3x3 -> maxpool -> LeakyReLU -> dropout -> linear(10).
    Both convolutions are wrapped with weight normalization (``weight_norm``),
    not batch normalization.

    Args:
        batch_size: Forwarded to ``Gaussian_Noise`` together with the image
            shape (1, 28, 28); presumably sizes the noise buffer — TODO confirm
            how a smaller final batch is handled there.
        std: Standard deviation forwarded to ``Gaussian_Noise``.
        p: Dropout probability applied right before the classifier.
        fm1: Number of feature maps produced by the first convolution.
        fm2: Number of feature maps produced by the second convolution.
    """

    def __init__(self, batch_size, std, p=0.5, fm1=16, fm2=32):
        super().__init__()
        self.fm1   = fm1
        self.fm2   = fm2
        self.std   = std
        self.gn    = Gaussian_Noise(batch_size, (1, 28, 28), std=self.std)
        self.act   = nn.LeakyReLU()
        self.drop  = nn.Dropout(p)
        # Weight normalization (reparameterizes each conv weight), not batch norm.
        self.conv1 = weight_norm(nn.Conv2d(1, self.fm1, 3, padding=1))
        self.conv2 = weight_norm(nn.Conv2d(self.fm1, self.fm2, 3, padding=1))
        self.mp    = nn.MaxPool2d(3, stride=2, padding=1)
        # 3x3 convs with padding=1 keep the spatial size; each
        # MaxPool2d(3, stride=2, padding=1) maps 28 -> 14 -> 7.
        self.fc    = nn.Linear(self.fm2 * 7 * 7, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10).

        ``x`` is expected to be (N, 1, 28, 28); the classifier's input size
        is derived from that spatial size. Noise is only added in training mode.
        """
        if self.training:
            x = self.gn(x)
        x = self.act(self.mp(self.conv1(x)))
        x = self.act(self.mp(self.conv2(x)))
        # flatten(1) preserves the batch dimension: a mis-sized input now fails
        # at self.fc instead of being silently regrouped by view(-1, ...).
        x = torch.flatten(x, 1)
        x = self.drop(x)
        return self.fc(x)


# MNIST experiments: for each seed, train/evaluate a freshly initialised model,
# collect its metrics, then save all runs together with the config.
accs = []
accs_best = []
losses = []
sup_losses = []
unsup_losses = []
idxs = []

st_for_exp = datetime.datetime.now().strftime('%Y_%m_%d_%H%M%S')
# vars(module) is the module's live __dict__, which also contains interpreter
# entries such as __name__ and __builtins__. Keep only the user settings so
# those entries are not forwarded as keyword arguments via **cfg below.
cfg = {key: value for key, value in vars(config).items() if not key.startswith('__')}

dataset = 'MNIST'
n_exp = cfg['n_exp']
seeds = cfg['seeds']
# Fail before any training starts rather than with an IndexError mid-run.
if len(seeds) < n_exp:
    raise ValueError(
        f"config.seeds has {len(seeds)} entries but n_exp={n_exp}; "
        "one seed is required per experiment"
    )

for seed in seeds[:n_exp]:
    model = CNN_for_MNIST(cfg['batch_size'], cfg['std'])
    acc, acc_best, l, sl, usl, total_labeled_idx = train_and_eval(model, seed, dataset=dataset, **cfg)
    accs.append(acc)
    accs_best.append(acc_best)
    losses.append(l)
    sup_losses.append(sl)
    unsup_losses.append(usl)
    idxs.append(total_labeled_idx)

save_exp_result(st_for_exp, losses, sup_losses, unsup_losses, accs, accs_best, idxs, **cfg)

현재 epoch: 10, w(t): 0.009741
Epoch: 10 /200, Step: 300 /600, Loss: 0.051674
Epoch: 10 /200, Step: 600 /600, Loss: 0.038299, Elapsed time: 3.61 sec /epoch
현재 epoch: 20, w(t): 0.027319
Epoch: 20 /200, Step: 300 /600, Loss: 0.000393
Epoch: 20 /200, Step: 600 /600, Loss: 0.005436, Elapsed time: 3.54 sec /epoch
현재 epoch: 30, w(t): 0.065535
Epoch: 30 /200, Step: 300 /600, Loss: 0.020612
Epoch: 30 /200, Step: 600 /600, Loss: 0.020291, Elapsed time: 3.70 sec /epoch
현재 epoch: 40, w(t): 0.134468
Epoch: 40 /200, Step: 300 /600, Loss: 0.001063
Epoch: 40 /200, Step: 600 /600, Loss: 0.003765, Elapsed time: 3.43 sec /epoch
현재 epoch: 50, w(t): 0.235999
Epoch: 50 /200, Step: 300 /600, Loss: 0.001731
Epoch: 50 /200, Step: 600 /600, Loss: 0.001579, Elapsed time: 3.47 sec /epoch
현재 epoch: 60, w(t): 0.354276
Epoch: 60 /200, Step: 300 /600, Loss: 0.001955
Epoch: 60 /200, Step: 600 /600, Loss: 0.002316, Elapsed time: 3.48 sec /epoch
현재 epoch: 70, w(t): 0.454900
Epoch: 70 /200, Step: 300 /600, Loss: 0.001318
