In [2]:
# Mount Google Drive under /content/drive (Colab only).
# NOTE(review): the dataset cells in this notebook use root='./data', not the Drive
# mount -- confirm this cell is still needed.
import google.colab.drive as drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets, transforms
from tqdm import tqdm

In [39]:
class Net1(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images (used with MNIST below), 10 classes.

    Shape trace for an (N, 1, 28, 28) input:
        conv1 5x5 -> (N, 20, 24, 24) -> max-pool 2 -> (N, 20, 12, 12)
        conv2 5x5 -> (N, 50, 8, 8)   -> max-pool 2 -> (N, 50, 4, 4)
        flatten   -> (N, 800) -> fc1 -> (N, 500) -> fc2 -> (N, 10)

    ``forward`` returns log-probabilities (log_softmax over dim 1), the form that
    ``F.nll_loss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # flatten keeps the batch dimension intact. The previous view(-1, 800) could
        # silently re-split a mis-sized input into a different number of rows
        # (e.g. 16 images of 32x32 -> 25 rows) instead of raising.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [40]:
class Net3(nn.Module):
    """Small LeNet-5-style CNN for 3-channel 32x32 images (used with CIFAR-10 below), 10 classes.

    Shape trace for an (N, 3, 32, 32) input:
        conv1 5x5 -> (N, 6, 28, 28)  -> pool -> (N, 6, 14, 14)
        conv2 5x5 -> (N, 16, 10, 10) -> pool -> (N, 16, 5, 5)
        flatten -> (N, 400) -> fc1 120 -> fc2 84 -> fc3 10

    ``forward`` returns log-probabilities (log_softmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (one pooling module shared by both conv stages).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.flatten(start_dim=1)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        logits = self.fc3(x)
        return F.log_softmax(logits, dim=1)

In [55]:
def train(model, mean_teacher, device, train_loader, test_loader, optimizer, epoch,
          consistency_weight=0.2, ema_decay=0.95, log_interval=100):
    """Run one epoch of Mean-Teacher-style training.

    Each step minimises
        nll_loss(student, target) + consistency_weight * mse(student, teacher)
    with the teacher's output computed without gradients. After the optimizer step,
    the teacher's parameters move towards the student's by an exponential moving average:
        teacher = ema_decay * teacher + (1 - ema_decay) * student

    Args:
        model: student network; its output is fed to F.nll_loss, so it must return
            log-probabilities.
        mean_teacher: network with the same parameter layout as ``model``; updated only
            by the EMA step here, never by ``optimizer``.
        device: device each batch is moved to.
        train_loader: iterable of (data, target) batches; ``.dataset`` is used for logging.
        test_loader: loader passed to ``test`` at every logging step.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the log line.
        consistency_weight: weight of the student/teacher MSE term (default 0.2).
        ema_decay: EMA decay for the teacher update (default 0.95).
        log_interval: every ``log_interval`` batches (starting at batch 0), print the loss
            and evaluate both networks on the full ``test_loader``. That evaluation is
            expensive, so raise this value for small batch sizes.

    Raises:
        ValueError: if ``log_interval`` is not positive.
    """
    if log_interval < 1:
        raise ValueError("log_interval must be a positive integer")

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        with torch.no_grad():
            mean_t_output = mean_teacher(data)

        # NOTE(review): this MSE compares log-softmax outputs; the Mean Teacher paper
        # uses MSE between softmax probabilities -- confirm which is intended.
        const_loss = F.mse_loss(output, mean_t_output)
        loss = F.nll_loss(output, target) + consistency_weight * const_loss
        loss.backward()
        optimizer.step()

        # In-place EMA update without autograd tracking:
        # lerp_(p, w) computes t + w * (p - t) == (1 - w) * t + w * p.
        with torch.no_grad():
            for mean_param, param in zip(mean_teacher.parameters(), model.parameters()):
                mean_param.lerp_(param, 1.0 - ema_decay)

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            test(model, device, test_loader)
            test(mean_teacher, device, test_loader)
            print()

In [42]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(f'<Test set> Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} \
            ({100. * correct / len(test_loader.dataset):.0f}%)')
    model.train()

In [43]:
# Training hyper-parameters read by the cells below.
epochs, lr, momentum, random_seed = (
    10,    # number of passes over the training set
    0.01,  # SGD learning rate
    0.5,   # SGD momentum
    42,    # seed passed to torch.manual_seed
)

In [44]:
# Seed torch's global RNG before any model is built, so weight initialisation is
# reproducible. Then use the GPU when one is available.
torch.manual_seed(random_seed)
dev = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(f"DEVICE : {dev}")

DEVICD : cuda


In [45]:
# Image -> float tensor via ToTensor, then normalise every channel with mean 0.5 and
# std 0.5, i.e. (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1].
transform_CIFAR10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [46]:
# Image -> float tensor via ToTensor, then standardise the single channel with
# mean 0.1307 / std 0.3081 (the statistics commonly quoted for the MNIST training set).
transform_MNIST = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

In [47]:
def CIFAR10_dataset_on(batch_size=4, num_workers=2, root='./data', test_batch_size=None):
    """Return ``(train_loader, test_loader)`` for CIFAR-10, downloading it if needed.

    Both splits use the module-level ``transform_CIFAR10``. The training loader
    shuffles; the test loader keeps a fixed order.

    Args:
        batch_size: training batch size (default 4, the value previously hard-coded).
            Small batches mean many optimizer steps per epoch.
        num_workers: DataLoader worker processes for both loaders (default 2).
        root: directory the dataset is stored in / downloaded to (default './data').
        test_batch_size: test batch size; defaults to ``batch_size``. Evaluation needs
            no gradients, so a larger value can make repeated evaluation much cheaper.
    """
    if test_batch_size is None:
        test_batch_size = batch_size

    trainset_CIFAR10 = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_CIFAR10)
    testset_CIFAR10 = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_CIFAR10)

    train_loader_CIFAR10 = DataLoader(trainset_CIFAR10, batch_size=batch_size, shuffle=True,
                                      num_workers=num_workers)
    test_loader_CIFAR10 = DataLoader(testset_CIFAR10, batch_size=test_batch_size, shuffle=False,
                                     num_workers=num_workers)
    return train_loader_CIFAR10, test_loader_CIFAR10

In [48]:
def MNIST_dataset_on(batch_size=1000, num_workers=2, root='./data', test_batch_size=None):
    """Return ``(train_loader, test_loader)`` for MNIST, downloading it if needed.

    Both splits use the module-level ``transform_MNIST``. The training loader
    shuffles; the test loader keeps a fixed order.

    Args:
        batch_size: training batch size (default 1000, the value previously hard-coded).
        num_workers: DataLoader worker processes for both loaders (default 2).
        root: directory the dataset is stored in / downloaded to (default './data').
        test_batch_size: test batch size; defaults to ``batch_size``.
    """
    if test_batch_size is None:
        test_batch_size = batch_size

    trainset_MNIST = datasets.MNIST(root=root, train=True, download=True, transform=transform_MNIST)
    testset_MNIST = datasets.MNIST(root=root, train=False, download=True, transform=transform_MNIST)

    train_loader_MNIST = DataLoader(trainset_MNIST, batch_size=batch_size, shuffle=True,
                                    num_workers=num_workers)
    test_loader_MNIST = DataLoader(testset_MNIST, batch_size=test_batch_size, shuffle=False,
                                   num_workers=num_workers)
    return train_loader_MNIST, test_loader_MNIST

In [58]:
# main
# Ask which dataset to use and build the matching student / mean-teacher pair.
# Invalid or non-numeric input re-prompts. The previous exit() call, in an
# IPython/Jupyter kernel, asks the kernel itself to exit rather than just ending this cell.
while True:
    print("Dataset Choice: [0: MNIST, 1: CIFAR10]")
    try:
        select_num = int(input())
    except ValueError:
        select_num = None

    if select_num == 0:
        train_loader, test_loader = MNIST_dataset_on()
        print("Dataset MNIST ON")
        model = Net1().to(device)
        mean_teacher = Net1().to(device)
        break

    elif select_num == 1:
        train_loader, test_loader = CIFAR10_dataset_on()
        print("Dataset CIFAR10 ON")
        model = Net3().to(device)
        mean_teacher = Net3().to(device)
        break

    print("Invalid choice; please enter 0 or 1.")

# The teacher is an exponential moving average of the student, so start it from the
# student's weights. Initialised independently, its early outputs (and therefore the
# consistency loss) would come from an unrelated random network.
mean_teacher.load_state_dict(model.state_dict())

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for epoch in range(1, epochs + 1):
    train(model, mean_teacher, device, train_loader, test_loader, optimizer, epoch)

Dataset Choice: [0: MNIST, 1: CIFAR10]
1
Files already downloaded and verified
Files already downloaded and verified
1
<Test set> Average loss: 2.3042, Accuracy: 1000/10000             (10%)
<Test set> Average loss: 2.3038, Accuracy: 1000/10000             (10%)

<Test set> Average loss: 2.3009, Accuracy: 1000/10000             (10%)
<Test set> Average loss: 2.3015, Accuracy: 1000/10000             (10%)



KeyboardInterrupt: ignored