In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchmetrics
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from loss import CenterLoss
import wandb
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import random

# GPU setting

In [2]:
# Seed the default torch generator first so everything below is reproducible.
torch.manual_seed(1)

# Use the GPU when one is visible; in that case also seed every CUDA device.
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed_all(1)
else:
    device = 'cpu'

# Model

In [3]:
class CLModel(nn.Module):
    """Small CNN that embeds an image into 2-D and classifies it into 10 classes.

    Three stages of (conv5x5 -> PReLU -> conv5x5 -> PReLU -> 2x2 max-pool) with
    32, 64 and 128 channels are followed by a bias-free linear layer down to a
    2-D embedding and a bias-free linear classifier on top of it.

    A single ``nn.PReLU`` module (one learnable slope) is shared by every
    activation in the network, including the one applied to the embedding.

    ``fc1`` expects 1152 flattened features, i.e. 128 channels on a 3x3 map,
    which is what a single-channel 28x28 input produces after three poolings.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order: weight initialisation consumes
        # the global RNG, so the order determines the initial weights.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5,
                               stride=1, padding=2, bias=False)
        self.conv1_1 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5,
                                 stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,
                               stride=1, padding=2, bias=False)
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5,
                                 stride=1, padding=2, bias=False)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5,
                               stride=1, padding=2, bias=False)
        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5,
                                 stride=1, padding=2, bias=False)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.prelu = nn.PReLU()
        self.fc1 = nn.Linear(in_features=1152, out_features=2, bias=False)
        self.fc2 = nn.Linear(in_features=2, out_features=10, bias=False)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: batch of inputs whose first dimension is the batch size.

        Returns:
            A ``(features, logits)`` tuple: the 2-D embedding after PReLU and
            the 10-way class scores computed from that embedding.
        """
        conv_pairs = (
            (self.conv1, self.conv1_1),
            (self.conv2, self.conv2_1),
            (self.conv3, self.conv3_1),
        )
        for first_conv, second_conv in conv_pairs:
            x = self.prelu(first_conv(x))
            x = self.prelu(second_conv(x))
            x = self.maxpool(x)

        # Collapse everything but the batch dimension before the linear layers.
        flat = x.view(len(x), -1)
        features = self.prelu(self.fc1(flat))
        logits = self.fc2(features)

        return features, logits

# data preprocessing

In [4]:
# ToTensor is the only transform applied to the images.
transform = transforms.ToTensor()

# MNIST training set, split 90% / 10% into training and validation subsets.
train = torchvision.datasets.MNIST(
    root="../mnist", train=True, download=True, transform=transform
)
train, valid = random_split(dataset=train, lengths=[0.9, 0.1])

# MNIST test set, kept whole.
test = torchvision.datasets.MNIST(
    root="../mnist", train=False, download=True, transform=transform
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(dataset=train, batch_size=128, num_workers=8, shuffle=True)
valid_loader = DataLoader(dataset=valid, batch_size=128, num_workers=8, shuffle=False)
test_loader = DataLoader(dataset=test, batch_size=128, num_workers=8, shuffle=False)

# train model

In [5]:
# Learning rates: lr_l for the network weights, lr_c for the center-loss parameters.
lr_l, lr_c = 1e-3, 0.5
# Weight of the center-loss term in the combined training loss.
lamda = 0.001
total_epochs = 100
# NOTE(review): the DataLoaders hard-code batch_size=128 and do not read this value.
batch_size = 128

# Build the network and move it to the selected device.
model = CLModel().to(device)

In [6]:
# Classification loss on the logits, optimised together with the network weights.
criterion1 = nn.CrossEntropyLoss()
optimizer1 = torch.optim.SGD(
    params=model.parameters(),
    lr=lr_l,
    momentum=0.9,
    weight_decay=5e-4,
)

# Center loss on the 2-D features; its own parameters get a separate plain-SGD optimizer.
criterion2 = CenterLoss().to(device)
optimizer2 = torch.optim.SGD(params=criterion2.parameters(), lr=lr_c)

### train

In [None]:
def plot_features(features, logits):
    """Scatter-plot 2-D features, one color per predicted class.

    Args:
        features: array of shape (N, 2) with the embedding of every sample.
        logits: array of shape (N, 10) with the class scores of every sample.

    NOTE(review): points are colored by the argmax of the model's logits
    (the predicted class), not by the ground-truth label, so misclassified
    samples take the color of the wrong class. Confirm this is intended.
    """
    predicted = logits.argmax(axis=1)
    for cls in range(10):
        mask = predicted == cls
        plt.scatter(features[mask, 0], features[mask, 1], color=f'C{cls}', s=1)
    plt.legend([str(cls) for cls in range(10)], loc='upper right')
    plt.show()


for epoch in trange(total_epochs):
    # Per-epoch running means, accumulated as plain Python floats.
    avg_train_loss = 0
    avg_train_softmax_loss = 0
    avg_train_center_loss = 0
    avg_valid_loss = 0
    avg_valid_softmax_loss = 0
    avg_valid_center_loss = 0

    all_features, all_labels = [], []

    # train
    model.train()
    n_train_batches = len(train_loader)
    for x_tr, y_tr in train_loader:
        x_tr = x_tr.to(device)
        y_tr = y_tr.to(device)

        optimizer1.zero_grad()
        optimizer2.zero_grad()

        coord, label = model(x_tr)
        train_softmax_loss = criterion1(label, y_tr)
        train_center_loss = criterion2(coord, y_tr)
        # NOTE(review): the center-loss parameters only receive gradients
        # through the lamda-weighted term, so their effective step size is
        # lr_c * lamda. Confirm this is intended.
        train_loss = train_softmax_loss + lamda*train_center_loss
        train_loss.backward()
        optimizer1.step()
        optimizer2.step()

        avg_train_loss += train_loss.item() / n_train_batches
        avg_train_softmax_loss += train_softmax_loss.item() / n_train_batches
        avg_train_center_loss += train_center_loss.item() / n_train_batches

        # detach() drops the autograd graph before copying to host memory.
        all_features.append(coord.detach().cpu().numpy())
        all_labels.append(label.detach().cpu().numpy())

    # plot training features
    all_features = np.concatenate(all_features, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    plot_features(all_features, all_labels)

    all_features, all_labels = [], []

    # validation
    model.eval()
    n_valid_batches = len(valid_loader)
    # No gradients are needed here; without no_grad() every batch would build
    # an autograd graph that is never used.
    with torch.no_grad():
        for x_val, y_val in valid_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            coord, label = model(x_val)
            valid_softmax_loss = criterion1(label, y_val)
            valid_center_loss = criterion2(coord, y_val)
            valid_loss = valid_softmax_loss + lamda*valid_center_loss

            # .item() keeps the running mean a float (not a device tensor),
            # matching the training-side accumulators and the printed summary.
            avg_valid_loss += valid_loss.item() / n_valid_batches
            avg_valid_softmax_loss += valid_softmax_loss.item() / n_valid_batches
            avg_valid_center_loss += valid_center_loss.item() / n_valid_batches

            all_features.append(coord.detach().cpu().numpy())
            all_labels.append(label.detach().cpu().numpy())

    # plot validation features
    all_features = np.concatenate(all_features, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    plot_features(all_features, all_labels)

    print(f'Epoch: {epoch}, train_loss: {avg_train_loss}, valid_loss: {avg_valid_loss}')

plt.close('all')