In [None]:
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lr_scheduler
from mpl_toolkits.mplot3d import Axes3D
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Sub-class multiplier: the custom losses and layers below size their outputs
# as K * 10 and map each target label to the K consecutive columns
# starting at target * K.
K = 1


In [None]:
class CenterLoss(nn.Module):
    """Center loss: pulls every feature vector towards a learnable center of its label.

    Args:
        num_classes: number of learnable centers (rows of ``self.centers``).
        feat_dim: dimensionality of each center / feature vector.
        use_gpu: when True the centers are moved to the default CUDA device.
    """

    def __init__(self, num_classes=K*10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        # Sample on the CPU first (as before) and only then move, so the
        # initial values do not depend on which RNG the device uses.
        centers = torch.randn(self.num_classes, self.feat_dim)
        if self.use_gpu:
            centers = centers.cuda()
        self.centers = nn.Parameter(centers)

    def forward(self, x, labels):
        """Return the summed squared distance to the labelled center, divided by batch size.

        Args:
            x: features of shape (batch_size, feat_dim).
            labels: integer labels of shape (batch_size,), indexing ``self.centers``.
        """
        batch_size = x.size(0)

        # distmat[i, j] = ||x_i||^2 + ||c_j||^2 - 2 * <x_i, c_j>, shape (batch_size, num_classes)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # Keyword beta/alpha form: the positional (beta, alpha, mat1, mat2)
        # overload used previously is rejected by current PyTorch releases.
        distmat = torch.addmm(distmat, x, self.centers.t(), beta=1, alpha=-2)

        # Build the one-hot mask on the same device as the features instead of
        # assuming CUDA.
        device = x.device
        classes = torch.arange(self.num_classes, device=device).long()
        labels = labels.to(device).unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))
        dist = distmat * mask.float()

        # The clamp runs after masking, so every masked-out entry still
        # contributes the lower bound 1e-12 to the sum.
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss

In [None]:
class CrossEntropyLoss(nn.Module):
    """nn.Module front-end for the custom autograd ``CrossEntropyLossFunction``."""

    def __init__(self) -> None:
        super(CrossEntropyLoss, self).__init__()

    def forward(self, outputs, targets):
        """Run the custom function on (outputs, targets) and hand back its (loss, tau) pair."""
        loss_and_tau = CrossEntropyLossFunction.apply(outputs, targets)
        return loss_and_tau

class CrossEntropyLossFunction(Function):
    """Cross-entropy against soft targets spread over the K sub-class columns of each label.

    For every sample the soft target ``tau`` is the softmax probability mass of
    the K columns that belong to its target label, renormalised to sum to one;
    all other columns are zero.  ``tau`` is treated as a constant when
    differentiating.
    """

    @staticmethod
    def forward(ctx, output, target):
        """Return ``(loss, tau)``.

        Args:
            output: raw scores with K * 10 columns (the column indices are taken
                modulo K * 10).
            target: integer labels, one per row of ``output``.
        """
        # log_softmax is already numerically stable, so no manual max-shift is needed.
        log_probs = F.log_softmax(output, dim=1)
        probs = log_probs.exp()

        # Columns of the K sub-classes of each sample's label, built on the
        # scores' own device rather than assuming CUDA.
        device = output.device
        indices = ((target.to(device) * K).view(-1, 1) + torch.arange(K, device=device)) % (K * 10)

        # tau (target): per-row renormalised probabilities of the label's sub-classes.
        subclass_probs = torch.gather(probs, 1, indices)
        denom = subclass_probs.sum(dim=1, keepdim=True) + 1e-12
        tau = torch.zeros_like(probs).scatter_(1, indices, subclass_probs / denom)

        ctx.save_for_backward(probs, tau)

        # Use log-probabilities so the reported value is the cross-entropy whose
        # gradient backward() returns; the previous -tau * probs did not match it.
        loss = torch.sum(-tau * log_probs, dim=1).mean()
        return loss, tau

    @staticmethod
    def backward(ctx, grad_output, grad_tau):
        """Gradient w.r.t. the raw scores; ``tau`` is constant, so ``grad_tau`` is unused."""
        probs, tau = ctx.saved_tensors
        # d(mean CE)/d(scores) = (softmax - tau) / batch, scaled by the upstream
        # gradient so that weighting or scaling the loss stays correct.
        grad_input = grad_output * (probs - tau) / probs.shape[0]
        return grad_input, None

In [None]:
class LinearFunction(Function):
    """Prototype layer: ``out = beta * (0.5*||W_neg - x||^2 - 0.5*||W_pos - x||^2)`` per output column.

    The backward pass is a hand-written update rule, not the plain autograd
    gradient: ``W_pos`` only receives updates in the columns that belong to a
    sample's target label and ``W_neg`` only in the remaining columns.
    """

    @staticmethod
    def forward(ctx, input, W_pos, W_neg, beta, target):
        """Compute the scaled difference of half squared distances.

        ``W_pos`` / ``W_neg`` are broadcast against ``input.unsqueeze(1)``; the
        last dimension is reduced.  ``target`` is only stored for backward.
        """
        d_X_WPos = W_pos - input.unsqueeze(1)
        d_X_WNeg = W_neg - input.unsqueeze(1)

        d_pos_2 = 0.5 * torch.sum(d_X_WPos ** 2, dim=-1)
        d_neg_2 = 0.5 * torch.sum(d_X_WNeg ** 2, dim=-1)
        d_neg_2_minus_d_pos_2 = d_neg_2 - d_pos_2

        # Keep the distance difference as well so backward does not recompute it.
        ctx.save_for_backward(d_X_WPos, d_X_WNeg, d_neg_2_minus_d_pos_2, beta, target)

        return beta * d_neg_2_minus_d_pos_2

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, W_pos, W_neg, beta, target); target gets None.

        ``grad_output`` is d(L)/d(out) with one column per output unit.
        """
        d_X_WPos, d_X_WNeg, d_neg_2_minus_d_pos_2, beta, target = ctx.saved_tensors
        # Every slot starts as None so the return below is valid no matter which
        # inputs require a gradient (grad_beta used to be unbound otherwise).
        grad_input = grad_W_pos = grad_W_neg = grad_beta = None

        grad_cols = grad_output.unsqueeze(2)

        # NOTE(review): none of the three gradients below carries the `beta`
        # factor that d(out)/d(.) would have; kept as-is on the assumption that
        # this is part of the custom update rule -- confirm.
        if ctx.needs_input_grad[0]:  # d(out)/d(X)
            grad_input = torch.sum(grad_cols * (d_X_WPos - d_X_WNeg), dim=1)

        if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            # One-hot mask over the K columns of each sample's label, built on
            # grad_output's device instead of assuming CUDA.
            device = grad_output.device
            indices = ((target.to(device) * K).view(-1, 1) + torch.arange(K, device=device)) % (K * 10)
            split_grad_pos = torch.zeros_like(grad_output).scatter(1, indices, 1)

            if ctx.needs_input_grad[1]:  # d(out)/d(w_pos), target columns only
                grad_W_pos = torch.sum(-d_X_WPos * grad_cols * split_grad_pos.unsqueeze(2), dim=0)

            if ctx.needs_input_grad[2]:  # d(out)/d(w_neg), non-target columns only
                split_grad_neg = torch.ones_like(grad_output) - split_grad_pos
                grad_W_neg = torch.sum(d_X_WNeg * grad_cols * split_grad_neg.unsqueeze(2), dim=0)

        if ctx.needs_input_grad[3]:  # d(out)/d(beta)
            grad_beta = grad_output * d_neg_2_minus_d_pos_2
            # Autograd requires the gradient to have beta's shape, so reduce the
            # broadcast dimensions (a scalar beta gets the full sum).
            if beta.dim() == 0:
                grad_beta = grad_beta.sum()
            else:
                grad_beta = grad_beta.sum_to_size(beta.shape)

        return grad_input, grad_W_pos, grad_W_neg, grad_beta, None

class MyLinearLayer(nn.Module):
    """Trainable wrapper around ``LinearFunction`` holding two weight banks and a scalar ``beta``."""

    def __init__(self, input_size, output_size):
        super(MyLinearLayer, self).__init__()
        self.input_size, self.output_size = input_size, output_size

        # Each bank has output_size * K rows; entries start uniformly in [-0.01, 0.01).
        bank_shape = (output_size * K, input_size)
        self.W_pos = nn.Parameter((torch.rand(*bank_shape) * 2 - 1) * 0.01)
        self.W_neg = nn.Parameter((torch.rand(*bank_shape) * 2 - 1) * 0.01)
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, target):
        """Apply ``LinearFunction`` to ``x`` with this layer's parameters and ``target``."""
        params = (self.W_pos, self.W_neg, self.beta)
        return LinearFunction.apply(x, *params, target)


In [None]:

def visualize(feat, labels, epoch):
    """Scatter-plot the first two feature columns per label and save the frame for ``epoch``.

    Args:
        feat: 2-D array supporting boolean-mask row indexing; columns 0 and 1 are plotted.
        labels: array of integer labels aligned with the rows of ``feat``.
        epoch: epoch number, shown in the figure and used in the file name.

    The figure is written to ``./images/epoch=<epoch>.jpg``.
    """
    plt.ion()
    # The first ten colours are distinct; the remaining entries repeat the last five.
    palette = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
               '#ff00ff', '#990000', '#999900', '#009900', '#009999',
               '#ff00ff', '#990000', '#999900', '#009900', '#009999',
               '#ff00ff', '#990000', '#999900', '#009900', '#009999',
               '#ff00ff', '#990000', '#999900', '#009900', '#009999',
               '#ff00ff', '#990000', '#999900', '#009900', '#009999']
    num_groups = 10 * K

    plt.clf()
    for i in range(num_groups):
        selected = labels == i
        # Wrap around the palette so more than 30 groups cannot raise IndexError.
        plt.plot(feat[selected, 0], feat[selected, 1], '.', c=palette[i % len(palette)])
    # One legend entry per plotted series instead of a fixed list of 30 labels.
    plt.legend([str(i) for i in range(num_groups)], loc='upper right')
    plt.xlim(-8, 8)
    plt.ylim(-8, 8)
    plt.text(-7.8, 7.3, "epoch=%d" % epoch)

    # savefig does not create missing directories.
    os.makedirs('./images', exist_ok=True)
    plt.savefig('./images/epoch=%d.jpg' % epoch)
    plt.draw()
    plt.pause(0.001)


In [None]:
class ConvNet(nn.Module):
    """LeNet-style CNN: three stages of two 5x5 conv + PReLU layers, then a 2-D embedding and a classifier.

    ``forward`` returns both the 2-D embedding and the class scores.
    """

    def __init__(self, num_classes):
        super(ConvNet, self).__init__()

        # Stage s holds conv{s}_1 / conv{s}_2 with 32, 64 and 128 output
        # channels; layers are registered in the order conv, prelu, conv, prelu.
        in_channels = 1
        for stage, out_channels in enumerate((32, 64, 128), start=1):
            for part in (1, 2):
                setattr(self, f"conv{stage}_{part}",
                        nn.Conv2d(in_channels, out_channels, 5, stride=1, padding=2))
                setattr(self, f"prelu{stage}_{part}", nn.PReLU())
                in_channels = out_channels

        self.fc1 = nn.Linear(128*3*3, 2)
        self.prelu_fc1 = nn.PReLU()
        # self.fc2 = MyLinearLayer(2, num_classes)
        self.fc2 = nn.Linear(2, num_classes)

    def forward(self, x, target):
        """Return ``(embedding, scores)``; ``target`` is accepted but not used here."""
        for stage in (1, 2, 3):
            for part in (1, 2):
                conv = getattr(self, f"conv{stage}_{part}")
                activation = getattr(self, f"prelu{stage}_{part}")
                x = activation(conv(x))
            # Halve the spatial resolution after each stage.
            x = F.max_pool2d(x, 2)

        flat = x.view(-1, 128*3*3)
        embedding = self.prelu_fc1(self.fc1(flat))
        scores = self.fc2(embedding)

        return embedding, scores

def _load_mnist(train):
    """Build the MNIST split selected by ``train`` with ToTensor + Normalize preprocessing."""
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    return datasets.MNIST('./mnist', download=True, train=train, transform=preprocessing)


train_dataset = _load_mnist(True)
test_dataset = _load_mnist(False)

# Only the training batches are shuffled.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Training configuration.
epoches = 100
class_num = 10
epoch_pre = 1
Center_step = 5
feat_dim = 2
weight = 0.1
alpha = 15
# NOTE(review): class_num, Center_step, feat_dim, weight and alpha are not read
# by any of the code visible in this notebook; in particular the center loss is
# added to the cross-entropy loss without the `weight` factor -- confirm intent.

net1 = ConvNet(10)

criterion1 = nn.CrossEntropyLoss()
# Keep the centers on the same device as the model: CenterLoss defaults to
# use_gpu=True, which calls .cuda() even when no GPU is available.
criterion2 = CenterLoss(use_gpu=torch.cuda.is_available())

# Separate optimizers: one for the network, one for the class centers.
optimizer4nn = torch.optim.SGD(net1.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
optimizer_centloss = torch.optim.SGD(criterion2.parameters(), lr=0.5)

# Halve the network learning rate every 20 scheduler steps.
sheduler = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.5)

total_step = len(train_loader)

if torch.cuda.is_available():
    net1 = net1.cuda()


# based on K
# label_dict[j] is the class of output column j: each of the C classes owns K
# consecutive columns, i.e. [0]*K + [1]*K + ... + [C-1]*K.
C = 10
L = C * K
label_dict = torch.arange(C).repeat_interleave(K)
# Follow the same device rule as the model instead of requiring CUDA.
if torch.cuda.is_available():
    label_dict = label_dict.cuda()

# Per-epoch history used by the plots further down.
train_losses = []
train_accs = []
test_losses = []
test_accs = []
for epoch in range(1,epoches+epoch_pre):
    # Features and labels of every training batch, collected for visualize().
    ip1_loader = []
    idx_loader = []

    net1 = net1.train()
    num_samples = len(train_loader.dataset)
    num_batches = len(train_loader)
    running_corrects = 0
    running_loss = 0.0
    for i, (im, label) in enumerate(train_loader):
        if torch.cuda.is_available():
            im = im.cuda()
            label = label.cuda()
        features, output = net1(im, label)

        loss_xent = criterion1(output, label)
        loss_cent = criterion2(features, label)

        loss =  loss_xent + loss_cent
        optimizer4nn.zero_grad()
        optimizer_centloss.zero_grad()

        loss.backward()

        optimizer4nn.step()
        optimizer_centloss.step()

        # Map the winning output column back to its class before scoring.
        _, preds_subclass = torch.max(output, dim=1)
        preds = label_dict[preds_subclass]
        running_corrects += torch.sum(preds == label)
        running_loss += loss.item()

        # Detach so the stored features do not keep every batch's autograd
        # graph alive until the end of the epoch.
        ip1_loader.append(features.detach())
        idx_loader.append((label))

        if (i+1)%100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss {}, xent loss: {}, cent loss:{:.4f}'
                   .format(epoch, epoches, i+1, total_step, loss.item(), loss_xent.item(), loss_cent.item()))

    epoch_acc = (running_corrects / num_samples) * 100
    epoch_loss = (running_loss / num_batches)
    train_losses.append(epoch_loss)
    train_accs.append(epoch_acc.cpu().numpy())
    print(f"epoch {epoch} -> Loss: {epoch_loss}, accuracy: {epoch_acc}")

    feat = torch.cat(ip1_loader, 0)
    labels = torch.cat(idx_loader, 0)
    visualize(feat.data.cpu().numpy(),labels.data.cpu().numpy(),epoch)

    sheduler.step()

    print('testing')
    net1.eval()
    correct = 0
    C_correct = 0
    test_loss = 0.0

    total = 0
    with torch.no_grad():
        for ti, (images, t_label) in enumerate(test_loader):
            if torch.cuda.is_available():
                images = images.cuda()
                t_label = t_label.cuda()
            features, outputs = net1(images, t_label)

            # Accumulate the same combined loss as in training, weighted by the
            # batch size so the division below yields a per-sample mean
            # (test_loss was previously never updated and stayed at 0).
            batch_loss = criterion1(outputs, t_label) + criterion2(features, t_label)
            test_loss += batch_loss.item() * t_label.size(0)

            _, predicted = torch.max(outputs.data, 1)
            # label_dict already lives on the evaluation device.
            preds = label_dict[predicted]
            correct += (preds == t_label).sum().item()
            total += t_label.size(0)

        test_loss = test_loss / len(test_loader.dataset)
        test_losses.append(test_loss)
        test_accuracy = 100.0 * (correct / total)
        test_accs.append(test_accuracy)
        print('Test Accuracy after {}th epoch on the 10000 test images: {} %'
              .format(epoch, test_accuracy))

# Plotting accuracy and loss during train
fig_train, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
ax_loss.plot(train_losses, label='Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(train_accs, label='Accuracy', color='orange')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()

fig_train.tight_layout()
# The notebook selects the non-interactive Agg backend at import time, where
# plt.show() displays nothing, so the figure is also written to disk.
fig_train.savefig('train_curves.png')
plt.show()


# Plotting the learning curve
epochsss=200  # NOTE(review): not read anywhere in the visible code
fig_curves, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(16, 6))
ax_loss.plot(train_losses, label='Training Loss')
ax_loss.plot(test_losses, label='val Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Loss Curves')
ax_loss.legend()
ax_loss.grid()

#plotting accuracy
ax_acc.plot(train_accs, label='Training Accuracy')
ax_acc.plot(test_accs, label='val Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Accuracy Curves')
ax_acc.legend()
ax_acc.grid()

fig_curves.tight_layout()
fig_curves.savefig('learning_curves.png')
plt.show()


Epoch [1/100], Step [100/469], Loss 2.304687976837158, xent loss: 2.3045570850372314, cent loss:0.0001
Epoch [1/100], Step [200/469], Loss 2.2574596405029297, xent loss: 2.2397027015686035, cent loss:0.0178
Epoch [1/100], Step [300/469], Loss 2.1190989017486572, xent loss: 2.0330231189727783, cent loss:0.0861
Epoch [1/100], Step [400/469], Loss 1.7896150350570679, xent loss: 1.6762152910232544, cent loss:0.1134
epoch 1 -> Loss: 2.172737587489553, accuracy: 27.918333053588867
testing
Test Accuracy after 1th epoch on the 10000 test images: 48.9 %
Epoch [2/100], Step [100/469], Loss 1.5140953063964844, xent loss: 1.160343885421753, cent loss:0.3538
Epoch [2/100], Step [200/469], Loss 1.3179188966751099, xent loss: 1.1386377811431885, cent loss:0.1793
Epoch [2/100], Step [300/469], Loss 1.1037542819976807, xent loss: 0.9161574840545654, cent loss:0.1876
Epoch [2/100], Step [400/469], Loss 1.0067496299743652, xent loss: 0.8282809257507324, cent loss:0.1785
epoch 2 -> Loss: 1.275485626543000

In [None]:
from PIL import Image
import os
import glob

def makeResultGif():
    """Assemble every ``./images/*.jpg`` frame, oldest file first, into ``output.gif``.

    Frames are ordered by file modification time, shown for 150 ms each, and
    the animation loops forever.  Nothing is written when no frame is found.
    """
    folder_path = './images/'
    file_extension = '*.jpg'

    image_files = glob.glob(os.path.join(folder_path, file_extension))
    if not image_files:
        # Indexing frames[0] below would raise IndexError on an empty folder.
        print('No {} files found in {}; skipping GIF creation.'.format(file_extension, folder_path))
        return
    sorted_image_files = sorted(image_files, key=os.path.getmtime)

    output_gif_path = 'output.gif'
    frame_duration = 150

    frames = []
    try:
        for image_file in sorted_image_files:
            frames.append(Image.open(image_file))

        frames[0].save(
            output_gif_path,
            save_all=True,
            append_images=frames[1:],
            duration=frame_duration,
            loop=0,
        )
    finally:
        # Release the file handles that Image.open keeps for each frame.
        for frame in frames:
            frame.close()

makeResultGif()