# 大作业要求

1.大作业以2-3人为一组完成，提交材料包括PPT（最后一次课将会展示课程成果）+ 最终的大作业报告（需组内各成员单独提交，内容为本人在课程大作业中的贡献以及对大作业问题的思考) + 提交包含分工情况及组内各成员工作量占比的表格。分工表格需组内所有成员签字确认；

2.禁止抄袭，发现雷同，所有雷同提交分数除以2；

3.写清楚大作业中的贡献和创新点，若使用开源代码和论文中的方法，在报告中必须注明（不可作为本人创新点），发现不标注引用，分数除以2。

最后一次课展示说明：
1.样例
PPT例子：https://www.sohu.com/a/166633625_642762
2.展示时间限制：
展示安排在最后一节课，每组展示时间为6分钟讲解+2分钟同学、助教、老师自由提问

大作业报告：强调个人对问题的理解，以及贡献，建议增加在提问反馈之后的改进结果。

最终评分为:30%展示评分+70%大作业报告

# 问题描述(Out-of-Distribution)

深度神经网络通常在独立同分布（Independent and Identically Distributed, i.i.d.）的假设下进行训练，即假设测试数据分布与训练数据分布相同。然而，在实际任务中这一假设往往并不成立，导致模型性能显著下降。虽然这种性能下降对于产品推荐等容错性较高的应用是可以接受的，但在医学等容错性较低的领域使用此类系统是危险的，因为它们可能导致严重事故。理想的人工智能系统应在分布外（Out-of-Distribution）的情况下仍具有较强的分布外泛化能力。而提高分布外泛化能力的关键，就是让模型学习到数据中的causal feature。  
一个简单的例子：以猫狗二分类为例，如果训练集中所有狗都在草地上，所有的猫都在沙发上，而测试集中所有的狗在沙发上，所有的猫在草地上，那么模型在没有测试集信息的情况下，很有可能根据训练集的信息把草地和狗联系在了一起，沙发和猫联系在了一起，当模型在测试集上测试时将会把在沙发上的狗误认为是猫。

# 数据集(Colored MNIST)

Colored MNIST是一个分布外泛化领域中常用的数据集，在该数据集中，训练集和测试集之间存在Out-of-Distribution情况，数字的颜色（color feature）与标签之间存在spurious correlation，即虚假相关（并非真实的因果关系）。从直观上来说，数字的形状为causal feature，数字的颜色为non-causal feature。该次大作业旨在探索如何让模型学习到causal feature来提高泛化能力。

In [None]:
import os

import numpy as np
from PIL import Image

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torchvision import transforms
from torchvision import datasets
import torchvision.datasets.utils as dataset_utils
from torch.utils.data import DataLoader
from Lenet import LeNet
from Lenet import LeNet_NoBN

%matplotlib inline
# Notebook-wide matplotlib defaults: default figure size, no interpolation
# between image pixels, and a grayscale default colormap.
plt.rcParams['figure.figsize'] = (10, 8)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

In [None]:
def color_grayscale_arr(arr, red=True):
    """Turn a 2-D grayscale image into an (H, W, 3) red or green image.

    Args:
        arr: 2-D NumPy array holding the grayscale intensities.
        red: If truthy, the intensities go into the red channel (index 0);
            otherwise they go into the green channel (index 1).

    Returns:
        A new array of shape (H, W, 3) with the same dtype as ``arr``; the
        two channels that are not selected are all zeros.
    """
    assert arr.ndim == 2
    height, width = arr.shape
    # Start from an all-black RGB canvas and fill only the chosen channel.
    channel = 0 if red else 1
    rgb = np.zeros((height, width, 3), dtype=arr.dtype)
    rgb[:, :, channel] = arr
    return rgb


class ColoredMNIST(datasets.VisionDataset):
    """
    Colored MNIST dataset for testing IRM. Prepared using procedure from https://arxiv.org/pdf/1907.02893.pdf

    Args:
      root (string): Root directory of dataset where ``ColoredMNIST/*.pt`` will exist.
      env (string): Which environment to load. Must be 1 of 'train1', 'train2', 'test', or 'all_train'.
      transform (callable, optional): A function/transform that  takes in an PIL image
        and returns a transformed version. E.g, ``transforms.RandomCrop``
      target_transform (callable, optional): A function/transform that takes in the
        target and transforms it.

    Raises:
      RuntimeError: If ``env`` is not one of the supported environment names.
    """

    def __init__(self, root='./data', env='train1', transform=None, target_transform=None):
        super(ColoredMNIST, self).__init__(root, transform=transform,
                                           target_transform=target_transform)

        # Build the three environment files on first use; no-op when they exist.
        self.prepare_colored_mnist()
        data_dir = os.path.join(self.root, 'ColoredMNIST')
        # NOTE(review): the .pt files hold pickled (PIL image, label) tuples. Recent
        # PyTorch releases may default torch.load to weights_only=True and reject
        # such objects -- confirm against the installed version.
        if env in ['train1', 'train2', 'test']:
            self.data_label_tuples = torch.load(os.path.join(data_dir, env + '.pt'))
        elif env == 'all_train':
            # Concatenation order matters to callers that index by position:
            # all of train1 first, then all of train2.
            self.data_label_tuples = torch.load(os.path.join(data_dir, 'train1.pt')) + \
                                     torch.load(os.path.join(data_dir, 'train2.pt'))
        else:
            raise RuntimeError(f'{env} env unknown. Valid envs are train1, train2, test, and all_train')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data_label_tuples[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data_label_tuples)

    def get_label(self):
        """Return the stored binary labels of all samples as a 1-D NumPy array.

        Only the labels are gathered; converting the full list of
        (image, label) tuples to an array would also try to convert every
        image. ``target_transform`` is not applied here.
        """
        return np.array([target for _, target in self.data_label_tuples])

    def prepare_colored_mnist(self):
        """Create ``train1.pt``, ``train2.pt`` and ``test.pt`` under ``<root>/ColoredMNIST``.

        Does nothing (apart from printing a message) when all three files
        already exist. Otherwise the MNIST training split is downloaded if
        needed, every digit gets a noisy binary label and a label-correlated
        colour, and the samples are split by index into the three environments.
        """
        colored_mnist_dir = os.path.join(self.root, 'ColoredMNIST')
        env_files = ('train1.pt', 'train2.pt', 'test.pt')
        if all(os.path.exists(os.path.join(colored_mnist_dir, name)) for name in env_files):
            print('Colored MNIST dataset already exists')
            return

        print('Preparing Colored MNIST')
        train_mnist = datasets.mnist.MNIST(self.root, train=True, download=True)

        train1_set = []
        train2_set = []
        test_set = []
        for idx, (im, label) in enumerate(train_mnist):
            if idx % 10000 == 0:
                print(f'Converting image {idx}/{len(train_mnist)}')
            im_array = np.array(im)

            # Assign a binary label y to the image based on the digit
            binary_label = 0 if label < 5 else 1

            # Flip label with 25% probability
            if np.random.uniform() < 0.25:
                binary_label = binary_label ^ 1

            # Color the image either red or green according to its possibly flipped label
            color_red = binary_label == 0

            # Pick the environment from the sample index, together with the
            # probability of flipping the colour in that environment.
            if idx < 20000:
                # 20% in the first training environment
                env_set, flip_prob = train1_set, 0.2
            elif idx < 40000:
                # 10% in the second training environment
                env_set, flip_prob = train2_set, 0.1
            else:
                # 90% in the test environment
                env_set, flip_prob = test_set, 0.9

            # Flip the color with a probability that depends on the environment
            if np.random.uniform() < flip_prob:
                color_red = not color_red

            colored_arr = color_grayscale_arr(im_array, red=color_red)
            env_set.append((Image.fromarray(colored_arr), binary_label))

        os.makedirs(colored_mnist_dir, exist_ok=True)
        torch.save(train1_set, os.path.join(colored_mnist_dir, 'train1.pt'))
        torch.save(train2_set, os.path.join(colored_mnist_dir, 'train2.pt'))
        torch.save(test_set, os.path.join(colored_mnist_dir, 'test.pt'))

# 初级部分：数据预处理

1.在Colored MNIST上训练和测试LeNet。  
2.在数据读取过程中增加数据预处理的方式（数据增广等），提高OOD泛化能力 【若选做高级部分可以跳过】

【OOD算法性能评价准则：在训练过程中只能接触训练集，不能在测试集上进行调参或者模型选择】

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# rest of the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed PyTorch's random number generator.
torch.manual_seed(1212)

def get_gaussian_noise(img):
    """Add Gaussian noise (mean 0, std 2) to ``img`` and clip the sum to [0, 255].

    Args:
        img: Array-like image; its ``shape`` determines the noise shape.

    Returns:
        The noisy image with every value limited to the range [0, 255].
    """
    noise_mean, noise_std = 0, 2
    perturbation = np.random.normal(noise_mean, noise_std, img.shape)
    return np.clip(img + perturbation, 0, 255)

def get_colored_pixel(img):
    """Return a 3x28x28 tensor filled with the channel values of one pixel of ``img``.

    NOTE(review): the indexing below looks written for a torch tensor of shape
    (3, 28, 28), where ``np.nonzero`` would yield one (channel, row, col) index
    triple per non-zero entry -- confirm the expected input type with the caller.
    An image without any non-zero entry is not handled.
    """
    # Positions of the non-zero entries of the image.
    tmp = np.nonzero(img)
    # Take the values of all channels at the position (tmp[0][1], tmp[0][2])
    # and shape them as (3, 1, 1) so they can be broadcast over the image plane.
    x = img[:, tmp[0][1], tmp[0][2]].reshape(3, 1, 1)
    # Repeat that single pixel over a 28x28 plane.
    x = np.broadcast_to(x, (3, 28, 28))
    return torch.tensor(x)

def get_mean(img):
    """Average ``img`` over dim 0 and replicate the result into 3 channels.

    The computation stays in PyTorch (no NumPy round trip), so it also works
    for tensors that do not live on the CPU, and the spatial size is taken
    from the input instead of being fixed to 28x28.

    Args:
        img: Floating-point tensor, e.g. of shape (C, H, W).

    Returns:
        A new tensor of shape (3, H, W) whose three channels all hold the
        mean of ``img`` over dim 0.
    """
    channel_mean = torch.mean(img, dim=0)
    # expand() only creates a broadcast view; clone() materialises an
    # independent tensor, as the original torch.tensor(...) copy did.
    return channel_mean.expand(3, *channel_mean.shape[-2:]).clone()

# Preprocessing pipeline used for the training environments. The commented-out
# entries are alternative preprocessing/augmentation steps that are disabled.
transform = transforms.Compose([
    # transforms.CenterCrop(7), 
    # transforms.RandomHorizontalFlip(p=0.5), 
    transforms.ToTensor(), 
    # transforms.Lambda(lambda img: get_mean(img)), 
    # transforms.Lambda(lambda img: get_gaussian_noise(img)), 
    transforms.Normalize((0.5, ), (0.5, )), 
    transforms.Resize([28, 28])
])
# 'all_train' is train1 followed by train2.
train_data = ColoredMNIST(env='all_train', transform=transform)
# NOTE(review): the test set only goes through ToTensor, while the training data
# is additionally normalized and resized -- confirm this mismatch is intended.
test_data = ColoredMNIST(env='test', transform=transforms.ToTensor())
# Samples with index < num_train are used for training; indices
# num_train..39999 of the concatenated training data serve as validation set.
num_train = 35000
train_loader = DataLoader(train_data, batch_size=128, 
                          sampler=torch.utils.data.SubsetRandomSampler(range(num_train)))
val_loader = DataLoader(train_data, batch_size=128, 
                        sampler=torch.utils.data.SubsetRandomSampler(range(num_train, 40000)))
test_loader = DataLoader(test_data, batch_size=2000, shuffle=True)
# print(train_data[0][0].shape)
# input shape is 3*28*28
# The two training environments as separate datasets.
train1 = ColoredMNIST(env='train1', transform=transform)
train2 = ColoredMNIST(env='train2', transform=transform)
# dtype of the first transformed training image.
dtype = train1[0][0].dtype

In [None]:
# Visualization of part of the dataset
# Datasets are looked up by name in a dict instead of eval()-ing strings.
splits = {'train1': train1, 'train2': train2, 'test_data': test_data}
# Column headers per label value j (0 -> 'red', 1 -> 'green').
column_titles = {0: 'red', 1: 'green'}
plt.figure(figsize=(5,16))
for i, (ds, dataset) in enumerate(splits.items()):
    # Labels are the same for both label columns, so fetch them once per split.
    label = dataset.get_label()
    for j in range(2):
        index = np.flatnonzero(label == j)
        index = np.random.choice(index, 20, replace=False)
        for k in range(5):
            # 8 columns: two image columns per split, separated by an empty column.
            plt.subplot(20, 8, k*8 + j + 3*i + 1)
            plt.imshow(dataset[index[k]][0].permute(1, 2, 0))
            plt.axis('off')
            if k == 0:
                # Put the column header on the top-row image of each column.
                plt.title(column_titles[j])
    plt.figtext(0.22 + i * 0.3, 0.91, s=ds, ha='center', fontsize=12, fontweight='bold')

# Save before show(): once the figure has been shown it may already be closed,
# in which case savefig() would write an empty image.
plt.savefig('std5.png')
plt.show()


In [None]:
# Unshuffled loaders for train1 and train2 that serve num_train12 samples per batch.
num_train12 = 20000
train1_loader, train2_loader = (
    DataLoader(env_data, batch_size=num_train12, shuffle=False)
    for env_data in (train1, train2)
)

In [None]:
def get_acc(scores, y):
    """Return the accuracy of thresholded scores against binary targets.

    Args:
        scores: Tensor of raw scores; a score strictly greater than 0 counts
            as a positive prediction.
        y: Tensor of targets compared element-wise with the predictions.

    Returns:
        A 0-dim float tensor with the fraction of matching entries.
    """
    predicted_positive = scores.gt(0)
    matches = torch.eq(predicted_positive, y)
    return matches.float().mean()

def data_split(data, device, dtype=torch.float32):
    """Move three (inputs, labels) batches to ``device`` and cast them to ``dtype``.

    Args:
        data: Indexable holding three (inputs, labels) pairs at positions 0-2.
        device: Target device for all tensors.
        dtype: Target dtype for all tensors, labels included.

    Returns:
        Tuple ``(x1, y1, x2, y2, x3, y3)``; every label tensor gets an extra
        dimension appended at position 1.
    """
    moved = []
    for inputs, labels in (data[0], data[1], data[2]):
        moved.append(inputs.to(device, dtype))
        moved.append(labels.to(device, dtype).unsqueeze(1))
    return tuple(moved)

In [None]:
# Load data to the compute device
# Only the first batch of each loader is used; the three loaders are advanced together.
data = next(zip(train1_loader, train2_loader, test_loader))
x1, y1, x2, y2, x3, y3 = data_split(data, device)

In [None]:
# Per-epoch accuracy histories (epoch index -> accuracy) for the ERM baseline.
train_acc_history_erm = {}
test_acc_history_erm = {}
# Scalar tensor with value 1.0 that requires gradients, moved to the compute device.
dummy = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device)
# Names of the training environments.
environments=['train1', 'train2']
def train_IRM(model, optimizer, epoch, dummy, 
              penalty_anneal_iters=20000, lamb=1e4):
    """Train ``model`` with plain empirical risk minimisation on train1 + train2.

    Despite its name, this variant adds no invariance penalty: the loss is the
    binary cross-entropy on the concatenation of the module-level tensors
    ``x1``/``x2`` (labels ``y1``/``y2``). Per-epoch accuracies are written to the
    module-level dicts ``train_acc_history_erm`` and ``test_acc_history_erm``.

    Args:
        model: Network mapping inputs to one logit per sample.
        optimizer: Optimizer used until epoch ``penalty_anneal_iters``.
        epoch: Number of full-batch update steps.
        dummy: Unused; kept so the signature matches the IRM trainers.
        penalty_anneal_iters: Epoch at which the optimizer is replaced by a
            fresh Adam instance (lr=1e-3).
        lamb: Unused; kept so the signature matches the IRM trainers.
    """
    model.train()
    x = torch.cat((x1, x2), dim=0)
    y = torch.cat((y1, y2), dim=0)
    for e in range(epoch):
        scores = model(x)
        # The test forward pass is for monitoring only, so no autograd graph
        # is needed for it.
        # NOTE(review): the model is still in train mode here; if it contains
        # BatchNorm/Dropout layers the reported test accuracy is affected -- confirm.
        with torch.no_grad():
            scores3 = model(x3)
        train_acc = get_acc(scores, y)
        test_acc = get_acc(scores3, y3)
        loss = F.binary_cross_entropy_with_logits(scores, y)
        total_loss = loss 
        train_acc_history_erm[e] = train_acc.detach().cpu()
        test_acc_history_erm[e] = test_acc.detach().cpu()
        if e == penalty_anneal_iters:
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()


        print("Epoch {},total_loss: {:3f}, tran_acc: {:3f}, test_acc: {:3f}"
            .format(e, loss, train_acc, test_acc))
        if e%10 == 0:
            print('------------------------')
    

# Training time: fit the ERM baseline for 300 full-batch epochs.
learning_rate = 4e-3
model11 = LeNet(out_channels=1).to(device=device)
optimizer = optim.Adam(model11.parameters(), lr=learning_rate)
train_IRM(model11, optimizer, epoch=300, dummy=dummy)

In [None]:
# Accuracy curves of the ERM baseline: training vs. test accuracy per epoch.
plt.figure(figsize=(7,5))
train_curve, test_curve = (
    plt.plot(list(history.values()))[0]
    for history in (train_acc_history_erm, test_acc_history_erm)
)
plt.xlabel('Epoch')
plt.ylabel('Acc')
plt.legend(handles=[train_curve, test_curve], 
           labels=['train_acc', 'test_acc'], loc='best')

# 中级部分：算法复现

https://github.com/facebookresearch/DomainBed  
1.复现Invariant Risk Minimization (IRM)算法。  
2.从以下论文中选择一个OOD算法复现，思考什么样的算法可以在此数据集上取得较好效果。  
    - Domain-Adversarial Training of Neural Networks (DANN)    
    - Out-of-Distribution Generalization via Risk Extrapolation (VREx)  
    - Learning Explanations that are Hard to Vary (AndMask)  
    - Self-Challenging Improves Cross-Domain Generalization (RSC)   
3.IRM算法对penalty weight参数较为敏感，如何改进，提高IRM算法稳定性。

### IRM复现

In [None]:
def compute_penalty(scores, y, dummy):
    """Return the squared gradient of the BCE loss w.r.t. the multiplier ``dummy``.

    Args:
        scores: Logits produced by the model.
        y: Binary targets with the same shape as ``scores``.
        dummy: Tensor requiring gradients that multiplies the logits.

    Returns:
        A 0-dim tensor holding the sum of the squared gradient entries. The
        gradient is computed with ``create_graph=True`` so the result can
        itself be back-propagated.
    """
    scaled_logits = scores * dummy
    risk = F.binary_cross_entropy_with_logits(scaled_logits, y)
    (dummy_grad,) = grad(risk, dummy, create_graph=True)
    return dummy_grad.pow(2).sum()

In [None]:
def train_IRM(model, optimizer, epoch, dummy, 
              penalty_anneal_iters=50, lamb=1e3, alpha=0.9):
    """Train ``model`` with the IRM objective on the two training environments.

    Uses the module-level tensors ``x1``/``y1`` and ``x2``/``y2`` as training
    environments and ``x3``/``y3`` for monitoring the test accuracy.

    Args:
        model: Network mapping inputs to one logit per sample.
        optimizer: Optimizer used until epoch ``penalty_anneal_iters``.
        epoch: Number of full-batch update steps.
        dummy: Scalar tensor requiring gradients, passed to ``compute_penalty``.
        penalty_anneal_iters: First epoch at which the penalty weight becomes
            ``lamb``; before that it is 1.0. The optimizer is re-created at
            this epoch as well.
        lamb: Penalty weight used from ``penalty_anneal_iters`` on.
        alpha: Unused; kept so the signature matches ``train_IRM_fluc``.

    Returns:
        Tuple ``(train_acc_history, test_acc_history)`` of dicts mapping the
        epoch index to the accuracy measured before that epoch's update.
    """
    model.train()
    # Constant placeholders so the log line has the same columns as the
    # fluctuating-penalty trainer.
    current_weight = lamb
    beta = 1
    train_acc_history_irm = {}
    test_acc_history_irm = {}
    for e in range(epoch):
        scores1 = model(x1)
        scores2 = model(x2)
        # The test forward pass is for monitoring only, so no autograd graph
        # is needed for it.
        # NOTE(review): the model is still in train mode here; if it contains
        # BatchNorm/Dropout layers the reported test accuracy is affected -- confirm.
        with torch.no_grad():
            scores3 = model(x3)
        train_acc = (get_acc(scores1, y1) + get_acc(scores2, y2)) / 2
        test_acc = get_acc(scores3, y3)
        penalty = (compute_penalty(scores1, y1, dummy) \
                + compute_penalty(scores2, y2, dummy)) / 2
        loss = (F.binary_cross_entropy_with_logits(scores1, y1)\
                + F.binary_cross_entropy_with_logits(scores2, y2)) / 2

        # Switch the penalty on at the same epoch at which the optimizer is
        # reset below (>=, not >), matching train_IRM_fluc.
        penalty_weight = lamb if e >= penalty_anneal_iters else 1.0
        total_loss = loss + penalty_weight * penalty
        train_acc_history_irm[e] = train_acc.detach().cpu()
        test_acc_history_irm[e] = test_acc.detach().cpu()
        if e == penalty_anneal_iters:
            # Start from a fresh Adam state when the loss scale jumps.
            optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=7e-4)
        if penalty_weight > 1:
            # Rescale so the gradient magnitude stays comparable for large weights.
            total_loss /= penalty_weight
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # The value logged as "total_loss" is the weighted penalty term.
        print("Epoch {},total_loss: {:3f}, tran_acc: {:3f}, test_acc: {:3f}, "
              "penalty_weight:{:3f}, current_weight:{:3f}, beta:{:3f}"
               .format(e, penalty_weight * penalty, train_acc, test_acc, penalty_weight,current_weight,beta))
        

    return train_acc_history_irm, test_acc_history_irm

# Training time: fit a fresh LeNet with IRM for 200 full-batch epochs.
learning_rate = 3e-3
model = LeNet(out_channels=1).to(device=device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=7e-4)
# train_IRM has a required `dummy` parameter; pass the shared scalar multiplier.
train_acc_history_irm, test_acc_history_irm = train_IRM(model, optimizer, epoch=200, dummy=dummy)

In [None]:
# IRM accuracy curves. Top panel: the whole run; bottom panel: the same
# curves with the first 100 epochs cut off.
plt.figure(figsize=(7,5))
for panel, first_epoch in enumerate((0, 100), start=1):
    plt.subplot(2, 1, panel)
    train_curve, = plt.plot(list(train_acc_history_irm.values())[first_epoch:])
    test_curve, = plt.plot(list(test_acc_history_irm.values())[first_epoch:])
    plt.xlabel('Epoch')
    plt.ylabel('Acc')
    plt.legend(handles=[train_curve, test_curve], 
               labels=['train_acc', 'test_acc'], loc='best')
plt.tight_layout()
plt.show()

### VREx复现

In [None]:
# Per-epoch accuracy histories (epoch index -> accuracy) filled by train_VERx.
train_acc_history_VREx = {}
test_acc_history_VREx = {}
def train_VERx(model, optimizer, epoch, 
              penalty_anneal_iters=100, lamb=1e4):
    """Train ``model`` with a V-REx style objective on the two training environments.

    The loss is the sum of the two per-environment binary cross-entropy risks
    plus ``penalty_weight`` times their squared difference. Uses the
    module-level tensors ``x1``/``y1``, ``x2``/``y2`` (training) and
    ``x3``/``y3`` (monitoring) and writes the per-epoch accuracies to the
    module-level dicts ``train_acc_history_VREx`` / ``test_acc_history_VREx``.

    Args:
        model: Network mapping inputs to one logit per sample.
        optimizer: Optimizer used until epoch ``penalty_anneal_iters``.
        epoch: Index of the last epoch; the loop runs ``epoch + 1`` update steps.
        penalty_anneal_iters: First epoch at which the penalty weight becomes
            ``lamb``; before that it is 1.0. The optimizer is re-created at
            this epoch as well.
        lamb: Penalty weight used from ``penalty_anneal_iters`` on.
    """
    model.train()
    for e in range(epoch + 1):
        scores1 = model(x1)
        scores2 = model(x2)
        # The test forward pass is for monitoring only, so no autograd graph
        # is needed for it.
        # NOTE(review): the model is still in train mode here; if it contains
        # BatchNorm/Dropout layers the reported test accuracy is affected -- confirm.
        with torch.no_grad():
            scores3 = model(x3)
        train_acc = (get_acc(scores1, y1) + get_acc(scores2, y2)) / 2
        test_acc = (get_acc(scores3, y3))
        loss1 = F.binary_cross_entropy_with_logits(scores1, y1)
        loss2 = F.binary_cross_entropy_with_logits(scores2, y2)
        # Squared gap between the two environment risks.
        penalty = (loss1 - loss2) ** 2
        loss = (loss1 + loss2)

        # Switch the penalty on at the same epoch at which the optimizer is
        # reset below (>=, not >).
        penalty_weight = lamb if e >= penalty_anneal_iters else 1.0
        total_loss = loss + penalty_weight * penalty
        if penalty_weight > 1:
            # Rescale so the gradient magnitude stays comparable for large weights.
            total_loss /= penalty_weight

        if e == penalty_anneal_iters:
            # Start from a fresh Adam state when the loss scale jumps.
            optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=7e-4)
        train_acc_history_VREx[e] = train_acc.detach().cpu()
        test_acc_history_VREx[e] = test_acc.detach().cpu()
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # The value logged as "total_loss" is the weighted penalty term.
        print("Epoch {},total_loss: {:3f}, tran_acc: {:3f}, test_acc: {:3f}"
            .format(e, penalty_weight * penalty, train_acc, test_acc))

# Training time: fit a fresh LeNet with the V-REx objective.
learning_rate = 3e-3
model = LeNet(out_channels=1).to(device=device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=7e-4)
train_VERx(model, optimizer, epoch=200)

In [None]:
# V-REx accuracy curves. Top panel: the whole run; bottom panel: the same
# curves with the first 100 epochs cut off.
plt.figure(figsize=(10,5))
for panel, first_epoch in enumerate((0, 100), start=1):
    plt.subplot(2, 1, panel)
    train_curve, = plt.plot(list(train_acc_history_VREx.values())[first_epoch:])
    test_curve, = plt.plot(list(test_acc_history_VREx.values())[first_epoch:])
    plt.xlabel('Epoch')
    plt.ylabel('Acc')
    plt.legend(handles=[train_curve, test_curve], 
               labels=['train_acc', 'test_acc'], loc='best')
plt.tight_layout()
plt.show()

### Penalty Fluctuation

In [None]:
# Scalar tensor with value 1.0 that requires gradients, moved to the compute device.
dummy = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device)
def train_IRM_fluc(model, optimizer, epoch, dummy, 
              penalty_anneal_iters=50, lamb=1e3, alpha=0.9):
    """Train ``model`` with IRM using a penalty weight that fluctuates with training accuracy.

    Up to ``penalty_anneal_iters`` the penalty weight is 1; at that epoch it is
    set to ``lamb``. Afterwards a target weight is derived from the ratio
    between the previous and the current training accuracy, and the penalty
    weight follows it as an exponential moving average. Uses the module-level
    tensors ``x1``/``y1``, ``x2``/``y2`` (training) and ``x3``/``y3`` (monitoring).

    Args:
        model: Network mapping inputs to one logit per sample.
        optimizer: Optimizer used until epoch ``penalty_anneal_iters``.
        epoch: Number of full-batch update steps.
        dummy: Scalar tensor requiring gradients, passed to ``compute_penalty``.
        penalty_anneal_iters: Epoch at which the penalty is switched on and the
            optimizer is re-created.
        lamb: Base penalty weight.
        alpha: Moving-average factor; the share of the previous penalty weight
            that is kept at every step.

    Returns:
        Tuple ``(train_acc_history, test_acc_history)`` of dicts mapping the
        epoch index to the accuracy measured before that epoch's update.
    """
    train_acc_history_irm = {}
    test_acc_history_irm = {}
    model.train()
    last_train_acc = 0.5
    penalty_weight = lamb
    current_weight = lamb
    beta = 1
    for e in range(epoch):
        scores1 = model(x1)
        scores2 = model(x2)
        # The test forward pass is for monitoring only, so no autograd graph
        # is needed for it.
        # NOTE(review): the model is still in train mode here; if it contains
        # BatchNorm/Dropout layers the reported test accuracy is affected -- confirm.
        with torch.no_grad():
            scores3 = model(x3)
        train_acc = (get_acc(scores1, y1) + get_acc(scores2, y2)) / 2
        test_acc = get_acc(scores3, y3)
        penalty = (compute_penalty(scores1, y1, dummy) \
                + compute_penalty(scores2, y2, dummy)) / 2
        loss = (F.binary_cross_entropy_with_logits(scores1, y1)\
                + F.binary_cross_entropy_with_logits(scores2, y2)) / 2
        # Penalty Fluctuation
        if e<penalty_anneal_iters:
            penalty_weight = 1
        elif e==penalty_anneal_iters:
            last_train_acc = train_acc
            penalty_weight = lamb
        else:
            # beta > 1 means the training accuracy dropped since the last
            # epoch, which lowers the target weight; beta < 1 raises it.
            beta = last_train_acc / train_acc
            current_weight = lamb * 1e4**(-torch.tanh(beta**2 - 1))
            # Smooth the target with an exponential moving average.
            penalty_weight = alpha * penalty_weight + (1-alpha) * current_weight
            last_train_acc = train_acc
        
        total_loss = loss + penalty_weight * penalty
        train_acc_history_irm[e] = train_acc.detach().cpu()
        test_acc_history_irm[e] = test_acc.detach().cpu()
        if e == penalty_anneal_iters:
            # Start from a fresh Adam state when the loss scale jumps.
            optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=7e-4)
        if penalty_weight > 1:
            # Rescale so the gradient magnitude stays comparable for large weights.
            total_loss /= penalty_weight
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # The value logged as "total_loss" is the weighted penalty term.
        print("Epoch {},total_loss: {:3f}, tran_acc: {:3f}, test_acc: {:3f},"
              "penalty_weight:{:3f}, current_weight:{:3f}, beta:{:3f}"
          .format(e, penalty_weight * penalty, train_acc, test_acc, penalty_weight,current_weight,beta))

    # Return this run's own training history (not the global ERM history).
    return train_acc_history_irm, test_acc_history_irm

In [None]:
# Sweep the penalty weight and compare plain IRM against the fluctuating-penalty
# variant, recording the test accuracy of the final epoch for each.
lambs = np.arange(1e5, 1e6 + 1, 1e5)
converged_acc_nofluc = []
converged_acc_fluc = []
for lamb in lambs:
    # Hand a plain Python float to the trainers rather than a NumPy scalar,
    # which would otherwise be mixed with tensors in arithmetic.
    lamb = float(lamb)
    model1 = LeNet_NoBN(out_channels=1).to(device=device)
    model2 = LeNet_NoBN(out_channels=1).to(device=device)
    learning_rate = 4e-3
    optimizer1 = optim.Adam(model1.parameters(), lr=learning_rate, weight_decay=7e-4)
    optimizer2 = optim.Adam(model2.parameters(), lr=learning_rate, weight_decay=7e-4)
    _, test1 = train_IRM(model1, optimizer1, epoch=100, dummy=dummy, lamb=lamb)
    # Read both methods at the same (final) epoch so the comparison is fair.
    acc = list(test1.values())[-1]
    converged_acc_nofluc.append(acc)
    print("lamb:{:e}, finish training model1:{:3f}".format(lamb, acc))
    _, test2 = train_IRM_fluc(model2, optimizer2, epoch=100, dummy=dummy, lamb=lamb)
    acc = list(test2.values())[-1]
    converged_acc_fluc.append(acc)
    print("lamb:{:e}, finish training model2: {}".format(lamb, acc))