In [22]:
import torch
from model import ResNet18
from torch.optim import SGD
import matplotlib.pyplot as plt 
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data
from copy import deepcopy
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity

# 参数设置




In [23]:
# Training hyper-parameters shared by the cells below. Key names (including
# the historical 'lr_shedule' spelling) are read by later cells, so keep them.
setting = {
    'epochs': 350,
    'lr': 0.1,
    'use_gpu': False,
    'lr_decay': 0.1,
    'batch_size': 128,
    'dataset': 'CIFAR10',  # informational label; not read by any visible cell
}
# Step-decay milestones at one third and two thirds of training
# (epochs 116 and 233 for 350 epochs); derived so they follow 'epochs'.
setting['lr_shedule'] = [setting['epochs'] // 3, 2 * setting['epochs'] // 3]

# 导入数据


In [24]:
# NOTE(review): these look like the ImageNet channel statistics rather than
# CIFAR-10's own — kept unchanged to preserve existing results; confirm intent.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# Augmented pipeline used only for optimisation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
# Deterministic pipeline (no augmentation) shared by both evaluation loaders.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
# Page-locked host memory only speeds up host->GPU copies; skip it on CPU.
loader_kwargs = dict(batch_size=setting['batch_size'], num_workers=8,
                     pin_memory=setting['use_gpu'])

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data/', train=True, transform=train_transform, download=True),
    shuffle=True, **loader_kwargs)
# "val" is the *training* split without augmentation: it measures training
# accuracy and the neural-collapse statistics on train data.
val_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data/', train=True, transform=eval_transform),
    shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data/', train=False, transform=eval_transform, download=True),
    shuffle=False, **loader_kwargs)


Files already downloaded and verified


# 导入模型

In [25]:
# ResNet18 comes from the local `model` module. Move it to the GPU when
# requested, using the same truthiness test as the device-selection cell so
# the model and the batches always land on the same device.
model = ResNet18()
if setting['use_gpu']:
    model = model.cuda()

# 优化器

In [26]:
# Device string used by every later `.to(device)` call.
device = 'cuda' if setting['use_gpu'] else 'cpu'
# SGD with momentum 0.9 and L2 weight decay 5e-4 over all model parameters;
# the learning rate is overwritten each epoch by descent_lr.
optimizer = SGD(
    model.parameters(),
    lr=setting['lr'],
    momentum=0.9,
    weight_decay=5e-4,
)

# 计算 Test Acc


In [27]:
def evaluate_batch(model, data_loader, device):
    """Print and return the top-1 accuracy of ``model`` over ``data_loader``.

    Also prints the raw counts and the top-5 accuracy. Top-k is capped at the
    number of logits, so models with fewer than 5 outputs still work.

    Args:
        model: module mapping a batch to logits of shape (batch, num_classes).
            Its train/eval mode is restored before returning.
        data_loader: iterable yielding ``(data, target)`` pairs.
        device: device the batches are moved to.

    Returns:
        Top-1 accuracy as a float in [0, 1]; 0.0 when the loader is empty.
    """
    was_training = model.training
    model.eval()
    correct = correct_t5 = num = 0
    try:
        with torch.no_grad():  # inference only: do not build autograd graphs
            for pack in data_loader:
                data, target = pack[0].to(device), pack[1].to(device)
                logits = model(data)
                _, pred = logits.max(1)
                k = min(5, logits.shape[1])
                _, pred_topk = logits.topk(k, dim=1)
                correct += pred.eq(target).sum().item()
                # target (B,) -> (B, 1) broadcasts against the (B, k) candidates.
                correct_t5 += pred_topk.eq(target.unsqueeze(1)).sum().item()
                num += data.shape[0]
    finally:
        model.train(was_training)
    acc = correct / num if num else 0.0
    acc_t5 = correct_t5 / num if num else 0.0
    print('Correct : ', correct)
    print('Num : ', num)
    print('Test ACC : ', acc)
    print('Top 5 ACC : ', acc_t5)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return acc

# 学习率调整

In [28]:
def descent_lr(epoch, optimizer, lr, lr_decay, epoch_schedule):
    """Apply step decay to the learning rate of every parameter group.

    The new rate is ``lr * lr_decay ** n``, where ``n`` is the number of
    milestones in ``epoch_schedule`` that ``epoch`` has strictly passed.
    A banner with the epoch and the last group's rate is printed.
    """
    passed = sum(1 for milestone in epoch_schedule if epoch > milestone)
    for group in optimizer.param_groups:
        group['lr'] = lr * lr_decay ** passed
    print('***********************************')
    print('epoch:', epoch)
    print('learning rate:', group['lr'])
    print('***********************************')

# Embedding 处理

In [29]:
def norm(x):
    """Euclidean (L2) norm taken along axis 1, i.e. one value per row of a 2-D array."""
    return np.linalg.norm(x, axis=1)


def _collect_embeddings(model, data_loader, device):
    """Stack ``model.forward_embedding`` outputs and labels over ``data_loader``.

    Returns:
        (embeddings, labels): numpy arrays of shape (N, D) and (N,); each
        sample's embedding is flattened to one row.

    Raises:
        ValueError: if the loader yields no batches.
    """
    embeddings, labels = [], []
    with torch.no_grad():  # feature extraction only; no autograd graph needed
        for pack in data_loader:
            data, target = pack[0].to(device), pack[1].to(device)
            embed = model.forward_embedding(data)
            embeddings.append(embed.reshape(embed.shape[0], -1).detach().cpu().numpy())
            labels.append(target.cpu().numpy())
    if not embeddings:
        raise ValueError('data_loader yielded no batches')
    return np.concatenate(embeddings, 0), np.concatenate(labels, 0)


def _offdiag_cosines(vectors):
    """Pairwise cosine similarities between rows of ``vectors``, diagonal removed.

    Both (i, j) and (j, i) are kept. All-zero rows get similarity 0 instead of NaN.
    """
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = vectors / norms
    # Drop the diagonal by position. Filtering on the value (``int(c) < 1``)
    # keeps self-similarities that round to 0.999... and drops genuine 1.0 pairs.
    return (unit @ unit.T)[~np.eye(len(vectors), dtype=bool)]


def neural_collapse_embedding(model, data_loader, device, num_classes=None):
    """Neural-collapse statistics of the embeddings produced by ``model``.

    With h = embedding, mu_c = class mean, mu_G = global mean, C = num_classes
    (NC1-NC3 as in Papyan, Han & Donoho, 2020):

    * Sigma_W = mean over all samples of (h - mu_c)(h - mu_c)^T
    * Sigma_B = mean over classes of (mu_c - mu_G)(mu_c - mu_G)^T
    * NC1     = trace(Sigma_W @ pinv(Sigma_B)) / C

    Args:
        model: module exposing ``forward_embedding(batch)``; its train/eval
            mode is restored before returning.
        data_loader: iterable of ``(data, target)`` pairs with integer class
            ids in ``[0, num_classes)``.
        device: device the batches are moved to.
        num_classes: number of classes; inferred as ``max(label) + 1`` if None.

    Returns:
        equal_norm_activation: std/mean of ||mu_c - mu_G|| over classes.
        class_weights: centred class means (C, D) divided by their Frobenius
            norm, comparable with a Frobenius-normalised classifier matrix.
        equa_ang: std of the pairwise cosines between centred class means.
        equa_ang_2: mean |cos + 1/(C-1)| over those pairs.
        intra_v: the NC1 statistic above.

    Raises:
        ValueError: if the loader is empty, fewer than two classes are used,
            or a class in ``range(num_classes)`` has no samples.
    """
    was_training = model.training
    model.eval()
    try:
        embedding_np, label_np = _collect_embeddings(model, data_loader, device)
    finally:
        model.train(was_training)
    print(embedding_np.shape)
    label_np = label_np.astype(int)
    if num_classes is None:
        num_classes = int(label_np.max()) + 1
    if num_classes < 2:
        raise ValueError('need at least two classes, got %d' % num_classes)

    n_samples, dim = embedding_np.shape
    global_mean = embedding_np.mean(axis=0, keepdims=True)  # mu_G, shape (1, D)
    class_mean = np.zeros((num_classes, dim))
    sigma_w = np.zeros((dim, dim))
    for k in range(num_classes):
        members = embedding_np[label_np == k]
        if len(members) == 0:
            raise ValueError('class %d has no samples in data_loader' % k)
        class_mean[k] = members.mean(axis=0)
        # Within-class scatter is centred on the *class* mean, not mu_G.
        deviation = members - class_mean[k]
        sigma_w += deviation.T @ deviation
    sigma_w /= n_samples

    centered_means = class_mean - global_mean  # rows are mu_c - mu_G
    sigma_b = centered_means.T @ centered_means / num_classes
    # Sigma_B has rank at most C (typically far below D), so it is singular;
    # the Moore-Penrose pseudo-inverse is required, not a true inverse.
    intra_v = np.trace(sigma_w @ np.linalg.pinv(sigma_b)) / num_classes

    mean_norms = np.linalg.norm(centered_means, axis=1)
    equal_norm_activation = np.std(mean_norms) / np.mean(mean_norms)

    cosine_sim = _offdiag_cosines(centered_means)
    equa_ang = np.std(cosine_sim)
    equa_ang_2 = np.mean(np.abs(cosine_sim + 1.0 / (num_classes - 1)))

    class_weights = centered_means / np.linalg.norm(centered_means, ord='fro')
    return equal_norm_activation, class_weights, equa_ang, equa_ang_2, intra_v


# Weight 处理

In [30]:
def weight_feature(weight):
    """Neural-collapse statistics of a classifier weight matrix.

    Args:
        weight: array of shape (num_classes, dim); row c is the classifier
            vector of class c (the training loop passes ``linear.weight``).

    Returns:
        equinorm_weight: std/mean of the row L2 norms.
        normalized_weight: ``weight`` divided by its Frobenius norm.
        equa_ang_w: std of the pairwise cosines between mean-centred rows.
        equa_ang_w_2: mean |cos + 1/(C-1)| over those pairs (C = row count).

    Raises:
        ValueError: if ``weight`` has fewer than two rows.
    """
    weight = np.asarray(weight, dtype=float)
    num_classes = weight.shape[0]
    if num_classes < 2:
        raise ValueError('need at least two classifier rows, got %d' % num_classes)
    weight_norm = np.linalg.norm(weight, axis=1)
    equinorm_weight = np.std(weight_norm) / np.mean(weight_norm)
    normalized_weight = weight / np.linalg.norm(weight, ord='fro')

    weight_n = weight - weight.mean(axis=0)
    row_norms = np.linalg.norm(weight_n, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0  # all-zero rows: similarity 0, not NaN
    unit = weight_n / row_norms
    # Remove self-similarities by position; a value filter such as
    # ``int(c) < 1`` keeps diagonal entries that round to 0.999...
    cosine_sim = (unit @ unit.T)[~np.eye(num_classes, dtype=bool)]
    equa_ang_w = np.std(cosine_sim)
    equa_ang_w_2 = np.mean(np.abs(cosine_sim + 1.0 / (num_classes - 1)))
    return equinorm_weight, normalized_weight, equa_ang_w, equa_ang_w_2

# 训练过程

In [None]:
# Per-epoch metric histories; the plotting cells below read these lists.
train_acc = []            # top-1 accuracy on the un-augmented training split (val_loader)
test_acc = []             # top-1 accuracy on the test split
eq_norm = []              # equinorm statistic of class means
eq_norm_w = []            # equinorm statistic of classifier rows
eq_ang = []               # std of class-mean cosines
eq_ang_w = []             # std of classifier-row cosines
eq_ang_max = []           # second equiangularity statistic, class means
eq_ang_w_max = []         # second equiangularity statistic, classifier rows
classifier_collapse = []  # Frobenius distance: normalised class means vs normalised weights
variation_collapse = []   # within-class variability statistic
criterion = nn.CrossEntropyLoss()
log_interval = 50         # training iterations between progress printouts

for ep in range(1, setting['epochs'] + 1):
    ### train
    model.train()
    descent_lr(ep, optimizer, setting['lr'], setting['lr_decay'], setting['lr_shedule'])
    loss_val = 0
    correct = num = 0
    for step, pack in enumerate(train_loader, start=1):
        data, target = pack[0].to(device), pack[1].to(device)
        logits = model(data)
        loss = criterion(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _, pred = logits.max(1)
        loss_val += loss.item()
        correct += pred.eq(target).sum().item()
        num += data.shape[0]
        if step % log_interval == 0:
            print('*******************************')
            print('epoch : ', ep)
            print('iteration : ', step)
            # Mean over the `log_interval` iterations accumulated since the
            # last printout.
            print('loss : ', loss_val / log_interval)
            print('Correct : ', correct)
            print('Num : ', num)
            print('Train ACC : ', correct / num)
            correct = num = 0
            loss_val = 0
    ### evaluate
    with torch.no_grad():
        print('Val ACC')
        train_acc_e = evaluate_batch(model, val_loader, device)
        print('Test ACC')
        test_acc_e = evaluate_batch(model, test_loader, device)
        train_acc.append(train_acc_e)
        test_acc.append(test_acc_e)
        equal_norm, clas_weight, equal_ang, equal_ang_2, intra_var = neural_collapse_embedding(model, val_loader, device)
        # Independent numpy snapshot of the classifier weights.
        classifier_weight = model.state_dict()['linear.weight'].detach().cpu().numpy().copy()
        equinorm_weight, n_weight, equal_ang_w, equal_ang_w_2 = weight_feature(classifier_weight)
        eq_norm.append(equal_norm)
        eq_norm_w.append(equinorm_weight)
        eq_ang.append(equal_ang)
        eq_ang_w.append(equal_ang_w)
        eq_ang_max.append(equal_ang_2)
        eq_ang_w_max.append(equal_ang_w_2)
        variation_collapse.append(intra_var)
        classifier_collapse.append(np.linalg.norm(clas_weight - n_weight, ord='fro'))


***********************************
epoch: 1
learning rate: 0.1
***********************************


# 可视化

## zero training error epoch

In [None]:
# First epoch whose un-augmented training accuracy exceeds the threshold.
# Epochs are 1-based (train_acc[0] is epoch 1), matching the x-axis
# range(1, len + 1) used by the plots below.
ZERO_TRAIN_ERROR_THRESHOLD = 0.999
zero_train_error_epoch = next(
    (epoch for epoch, acc in enumerate(train_acc, start=1)
     if acc > ZERO_TRAIN_ERROR_THRESHOLD),
    None,
)
if zero_train_error_epoch is None:
    raise ValueError('training accuracy never exceeded %s; no zero-training-error '
                     'epoch to mark' % ZERO_TRAIN_ERROR_THRESHOLD)
print('Zero Training Error')
print(zero_train_error_epoch)

## Training Acc

In [None]:
# Training-split accuracy per epoch, with the zero-training-error epoch marked.
fig, ax = plt.subplots()
ax.set_title('Train ACC')
ax.plot(range(1, len(train_acc) + 1), train_acc)
ax.vlines(zero_train_error_epoch, 0, 1, colors="r", linestyles="dashed")
fig.savefig('train_acc.png')
plt.close(fig)

## Testing Acc

In [None]:
# Test-split accuracy per epoch, with the zero-training-error epoch marked.
fig, ax = plt.subplots()
ax.set_title('Test ACC')
ax.plot(range(1, len(test_acc) + 1), test_acc)
ax.vlines(zero_train_error_epoch, 0, 1, colors="r", linestyles="dashed")
fig.savefig('test_acc.png')
plt.close(fig)

## Equal Norm

In [None]:
# Equinorm statistic (std/mean of norms) for class means vs classifier rows.
epoch_axis = range(1, len(eq_norm) + 1)
fig, ax = plt.subplots()
ax.set_title('Equal Norm')
ax.plot(epoch_axis, eq_norm, label='Activation')
ax.plot(epoch_axis, eq_norm_w, label='Weight')
ax.set_xlabel('epoch')
ax.legend()
# A std/mean ratio can exceed 1; axvline spans the whole y-range, whereas
# vlines(..., 0, 1) stops at y=1.
ax.axvline(zero_train_error_epoch, color="r", linestyle="dashed")
fig.savefig('equalnorm.png')
plt.close(fig)

## Equiangularity

In [None]:
# Equiangularity statistics for class means (activations) vs classifier rows
# (weights): spread of pairwise cosines, and the second statistic per epoch.
epoch_axis = range(1, len(eq_ang) + 1)
for title, activation, weight, filename in (
    ('Equal Ang', eq_ang, eq_ang_w, 'equalang.png'),
    ('Equal Ang Max', eq_ang_max, eq_ang_w_max, 'equalang_max.png'),
):
    fig, ax = plt.subplots()
    ax.set_title(title)
    ax.plot(epoch_axis, activation, label='Activation')
    ax.plot(epoch_axis, weight, label='Weight')
    ax.set_xlabel('epoch')
    ax.legend()
    # axvline spans the full y-range; these statistics can slightly exceed 1.
    ax.axvline(zero_train_error_epoch, color="r", linestyle="dashed")
    fig.savefig(filename)
    # Close every figure so none is left open in the kernel.
    plt.close(fig)


## within-class variation collapses

In [None]:
# Within-class variability statistic per epoch.
fig, ax = plt.subplots()
ax.set_title('Intra Class  Collapse')
ax.plot(range(1, len(variation_collapse) + 1), variation_collapse)
ax.set_xlabel('epoch')
# This statistic is not bounded by 1; axvline spans the full y-range,
# whereas vlines(..., 0, 1) stops at y=1.
ax.axvline(zero_train_error_epoch, color="r", linestyle="dashed")
fig.savefig('variationcollapse.png')
plt.close(fig)

## Classifier converges to train class-means

In [None]:
# Frobenius distance between the normalised class-mean matrix and the
# normalised classifier matrix, per epoch.
fig, ax = plt.subplots()
ax.set_title('Weight Collapse')
ax.plot(range(1, len(classifier_collapse) + 1), classifier_collapse)
ax.set_xlabel('epoch')
# The distance between two unit-Frobenius-norm matrices ranges up to 2;
# axvline spans the full y-range, whereas vlines(..., 0, 1) stops at y=1.
ax.axvline(zero_train_error_epoch, color="r", linestyle="dashed")
fig.savefig('weightcollapse.png')
plt.close(fig)