# Consistency regularization Tutorial

- 실습조교: 배진수(wlstn215@korea.ac.kr), 안시후(sihuahn@korea.ac.kr), 김현지(99ktxx@korea.ac.kr)

## Colab GPU 연결

### 런타임 -> 런타임 유형 변경 -> 하드웨어 가속기 -> GPU

In [None]:
import torch
# Report whether PyTorch can see a CUDA device. As the last expression of the
# cell, the boolean is presumably echoed by the notebook.
torch.cuda.is_available()

## 0. 모듈 불러오기

In [None]:
''' github+colab 교육생분들 '''
# !git clone https://github.com/bogus215/LG-EDUCATION3.git

''' 기본 모듈 및 시각화 모듈 '''
import random
import time
import numpy as np
import easydict # dictionary의 속성을 dot(.)을 사용하여 표기가능
from tqdm.auto import tqdm

'''Neural Network을 위한 딥러닝 모듈'''
from torch import nn
from torch import optim
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader

'''Custom 모듈 (WIdeResNet, RandAugment)'''
''' github+colab 교육생분들 '''
# import sys
# sys.path.append("./LG-EDUCATION3")
from models import Wide_ResNet, initialize_weights
from augmentation import *

## 1. 준지도학습 분석 환경 세팅

In [None]:
# Hyperparameter settings, wrapped in an EasyDict so that entries can be read
# with attribute (dot) access, e.g. ``args.lr``.
args = easydict.EasyDict(dict(
    # Base configuration for learning
    seed=0,
    gpu=0,
    lambda_u=20,
    total_epoch=10,

    # Data
    data_path="./data",
    num_labeled=10000,
    num_unlabeled=40000,
    num_classes=10,
    size=32,
    batch_size=256,

    # WideResNet model
    depth=10,
    widen_factor=2,
    dropout=0.1,

    # Optimization
    lr=0.02,            # training learning rate
    weight_decay=2e-4,  # training weight decay
))

In [None]:
# Fix the random seed of every generator used in this notebook
# (Python, NumPy, PyTorch CPU and all CUDA devices), in that order.
for set_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    set_seed(args.seed)

In [None]:
# Select the compute device: the configured GPU index when CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    args.device = torch.device('cuda', args.gpu)
else:
    args.device = torch.device('cpu')

## 2. 데이터셋 & 데이터로더 준비 (Labeled & Unlabeled)

In [None]:
# Download CIFAR-10 (if not already present under args.data_path) and load the
# training split first, then the test split.
base_dataset, test_dataset = (
    datasets.CIFAR10(args.data_path, train=is_train, download=True)
    for is_train in (True, False)
)

In [None]:
# Number of samples to take from each class for the labeled / unlabeled sets
labeled_per_class = args.num_labeled // args.num_classes
unlabeled_per_class = args.num_unlabeled // args.num_classes

print("클래스 1개당 Labeled 데이터 개수: ",labeled_per_class)
print("클래스 1개당 Unlabeled 데이터 개수: ",unlabeled_per_class)

# Full training data (source of both the labeled and the unlabeled set)
images, labels = np.array(base_dataset.data), np.array(base_dataset.targets)

# For every class, the first `labeled_per_class` occurrences (in dataset order)
# become labeled samples and the next `unlabeled_per_class` become unlabeled.
# NOTE: slicing silently yields fewer samples if a class does not contain
# labeled_per_class + unlabeled_per_class images; the totals printed below
# make such a shortfall visible.
labeled_idx, unlabeled_idx = [], []
for i in range(args.num_classes):
    class_idx = np.where(labels == i)[0]
    unlabeled_end = labeled_per_class + unlabeled_per_class
    labeled_idx.extend(class_idx[:labeled_per_class])
    unlabeled_idx.extend(class_idx[labeled_per_class:unlabeled_end])

# Total number of labeled samples
print("Labeled 데이터 총 갯수: ",len(labeled_idx))

# Total number of unlabeled samples (label typo "Unabeled" fixed)
print("Unlabeled 데이터 총 갯수: ",len(unlabeled_idx))

labeled_images, unlabeled_images = images[labeled_idx], images[unlabeled_idx]
labeled_targets, unlabeled_targets = labels[labeled_idx], labels[unlabeled_idx]

In [None]:
# Stratified train/validation split of the labeled data (80% / 20%).
from sklearn.model_selection import train_test_split

# random_state is pinned to args.seed so the split does not depend on the
# current state of the global NumPy generator; without it, re-running this
# cell (or running any other NumPy-random code first) produces a different
# split and can move samples between training and validation.
(train_labeled_images, valid_labeled_images,
 train_labeled_targets, valid_labeled_targets) = train_test_split(
    labeled_images,
    labeled_targets,
    test_size=.2,
    stratify=labeled_targets,
    random_state=args.seed,
)

print(train_labeled_images.shape)
print(valid_labeled_images.shape)
print(train_labeled_targets.shape)
print(valid_labeled_targets.shape)

In [None]:
# Pool of augmentations available to RandAugment.
# Every entry is an (operation, max_v, bias) tuple; all four operations share
# the same strength settings here. Tune these per augmentation as needed.
cifar10_aug_list = [
    (operation, 4, 0.5)
    for operation in (Color, Contrast, Brightness, Sharpness)
]

# RandAugment: apply `n` randomly drawn operations from a pool to an image.
class RandAugmentCIFAR(object):
    """Callable that applies ``n`` randomly chosen augmentations.

    Args:
        n: Number of operations drawn per call (converted with ``int``).
        m: Magnitude forwarded to each operation as ``v``.
        aug_list: Sequence of ``(op, max_v, bias)`` tuples to draw from.
    """

    def __init__(self, n, m, aug_list):
        self.n = int(n)
        self.m = m
        self.augment_pool = aug_list

    def __call__(self, img):
        """Return ``img`` after applying the drawn operations in order."""
        # Draw with replacement, so the same operation may be applied twice.
        chosen = random.choices(self.augment_pool, k=self.n)
        augmented = img
        for transform, upper, offset in chosen:
            augmented = transform(augmented, v=self.m, max_v=upper, bias=offset)
        return augmented

In [None]:
# Per-channel mean and standard deviation used for normalization.
# NOTE(review): presumably obtained from the CIFAR-10 training images as
# base_dataset.data.mean(axis=(0, 1, 2)) / 255 and
# base_dataset.data.std(axis=(0, 1, 2)) / 255 -- confirm before reusing.
cifar10_mean = [0.49139968, 0.48215841, 0.44653091]
cifar10_std = [0.24703223, 0.24348513, 0.26158784]

# Transform for the labeled training set: random horizontal flip and padded
# random crop, followed by tensor conversion and normalization.
transform_labeled = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(args.size, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Transform for evaluation data: no augmentation, only tensor conversion and
# normalization.
transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Transform for the unlabeled dataset: every image yields two views.
class CustomTransform(object):
    """Build a weakly and a strongly augmented view of the same image.

    Args:
        args: Configuration object; only ``args.size`` (crop size) is read.
        n: Number of RandAugment operations for the strong view.
        m: RandAugment magnitude for the strong view.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.
        aug_list: Augmentation pool handed to ``RandAugmentCIFAR``.
    """

    def __init__(self, args, n, m, mean, std, aug_list):
        self.n = n
        self.m = m

        # Weak view: random horizontal flip + padded random crop.
        weak_steps = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=args.size, padding=4),
        ]
        # Strong view: RandAugment applied to the original image.
        strong_steps = [
            transforms.ToPILImage(),
            RandAugmentCIFAR(n=n, m=m, aug_list=aug_list),
        ]
        # Shared tail: tensor conversion followed by normalization.
        tensor_steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]

        self.weak_aug = transforms.Compose(weak_steps)
        self.strong_aug = transforms.Compose(strong_steps)
        self.normalize = transforms.Compose(tensor_steps)

    def __call__(self, x):
        """Return ``(weak_view, strong_view)`` as normalized tensors."""
        weak_view = self.weak_aug(x)
        strong_view = self.strong_aug(x)
        return self.normalize(weak_view), self.normalize(strong_view)


transform_unlabeled = CustomTransform(
    args,
    n=1,
    m=1,
    mean=cifar10_mean,
    std=cifar10_std,
    aug_list=cifar10_aug_list,
)

In [None]:
# PyTorch dataset over in-memory images and targets.
class CIFAR10SSL(Dataset):
    """Dataset that pairs each image with its target and applies a transform.

    Args:
        X: Indexable collection of images.
        Y: Indexable collection of targets; its length defines the dataset size.
        transform: Callable applied to each image, or ``None`` to return the
            image unchanged.
    """

    def __init__(self, X, Y, transform):
        self.x_data, self.targets = X, Y
        self.transform = transform

    def __len__(self):
        """Return the number of targets."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``, transforming the image if configured."""
        sample = self.x_data[index]
        label = self.targets[index]

        if self.transform is None:
            return sample, label
        return self.transform(sample), label

In [None]:
# Wrap the arrays in datasets and build the data loaders.
train_labeled_dataset = CIFAR10SSL(X=train_labeled_images, Y=train_labeled_targets, transform=transform_labeled)
valid_labeled_dataset = CIFAR10SSL(X=valid_labeled_images, Y=valid_labeled_targets, transform=transform_test)

unlabeled_dataset = CIFAR10SSL(X=unlabeled_images, Y=unlabeled_targets, transform=transform_unlabeled)
# NOTE: `test_dataset` is rebound from the torchvision dataset to a CIFAR10SSL
# wrapper here, so this cell cannot be re-run without re-running the download
# cell first (the wrapper has no `.data` attribute).
test_dataset = CIFAR10SSL(X=np.array(test_dataset.data), Y=np.array(test_dataset.targets), transform=transform_test)

# Training loaders shuffle and drop the last partial batch so that labeled and
# unlabeled batches always have the same size.
train_labeled_loader = DataLoader(train_labeled_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

# Validation must see every sample in a fixed order. Shuffling together with
# drop_last=True discarded a different random subset of the validation set on
# every epoch, which made the validation loss used for checkpoint selection
# noisy.
valid_labeled_loader = DataLoader(valid_labeled_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)

unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

## 3. Semi-supervised learning (Consistency regularization)

In [None]:
# Build the Wide-ResNet classifier from the configured hyperparameters.
model = Wide_ResNet(
    num_classes=args.num_classes,
    depth=args.depth,
    widen_factor=args.widen_factor,
    dropout_rate=args.dropout,
)
# Move the model to the selected device.
model.to(args.device)

In [None]:
# Supervised loss and optimizer (SGD with Nesterov momentum).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.weight_decay,
    nesterov=True,
)

In [None]:
# Semi-supervised training with consistency regularization.
# Each step combines the supervised cross-entropy on a labeled batch with an
# MSE consistency loss that pulls the prediction on the strongly augmented
# unlabeled view towards the (detached) prediction on the weakly augmented one.
since = time.time()

train_loss = []
valid_loss = []
best_loss = np.inf

for epoch in (pbar :=tqdm(range(args.total_epoch))):

    # ---------------- Training phase ----------------
    corrects = 0
    total = 0
    epoch_loss = 0
    model.train()
    unlabeled_iter = iter(unlabeled_loader)
    for label_inputs, label_targets in train_labeled_loader:

        label_inputs = label_inputs.to(args.device)
        label_targets = label_targets.to(args.device)

        # Restart the unlabeled loader once it is exhausted. Only StopIteration
        # is caught so that genuine errors (data loading failures, keyboard
        # interrupts, ...) are no longer silently swallowed and retried.
        try:
            (weak_inputs, strong_inputs), _ = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(unlabeled_loader)
            (weak_inputs, strong_inputs), _ = next(unlabeled_iter)
        weak_inputs = weak_inputs.to(args.device)
        strong_inputs = strong_inputs.to(args.device)

        # Reset the parameter gradients
        optimizer.zero_grad()

        # Forward all three groups in a single pass, then split the logits
        # back by the actual size of each group (no equal-size assumption).
        full_logits = model(torch.cat([label_inputs, weak_inputs, strong_inputs], dim=0))
        label_logits, weak_unlabeled_logits, strong_unlabeled_logits = full_logits.split(
            [label_inputs.size(0), weak_inputs.size(0), strong_inputs.size(0)]
        )
        label_loss = criterion(label_logits, label_targets)

        # The weak-view prediction is the target, so it is detached
        unlabeled_loss = F.mse_loss(input=strong_unlabeled_logits.softmax(-1), target=weak_unlabeled_logits.softmax(-1).detach())
        tot_loss = args.lambda_u*unlabeled_loss + label_loss

        # Backward
        tot_loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss
        epoch_loss += tot_loss.item()

        # Predict the class with the largest logit
        preds = label_logits.argmax(dim=-1)

        # Accumulate the number of correct labeled predictions
        batch_corrects = torch.sum(preds == label_targets).item()
        corrects += batch_corrects
        total += label_targets.size(0)

        pbar.set_description(f'Loss : {tot_loss.item():.2f}  |  Accuracy: {batch_corrects/label_targets.size(0):.2f}')

    # Mean training loss per batch for this epoch. It is already an average,
    # so it must not be divided by the batch count a second time when shown.
    epoch_loss /= len(train_labeled_loader)
    train_loss.append(epoch_loss)
    pbar.set_description(f'{epoch+1} Loss : {epoch_loss:.2f}  |  Accuracy : {(corrects/total):.2f}')

    # ---------------- Validation phase ----------------
    corrects = 0
    total = 0
    epoch_loss = 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in valid_labeled_loader:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)

            # Forward
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Accumulate the loss weighted by batch size so that a smaller
            # final batch does not bias the epoch average
            epoch_loss += loss.item() * targets.size(0)

            # Predict the class with the largest logit
            preds = outputs.argmax(dim=-1)

            # Accumulate the number of correct predictions
            batch_corrects = torch.sum(preds == targets).item()
            corrects += batch_corrects
            total += targets.size(0)

            pbar.set_description(f'Loss : {loss.item():.2f}  |  Accuracy: {batch_corrects/targets.size(0):.2f}')

    # Mean validation loss per sample for this epoch
    epoch_loss /= total
    valid_loss.append(epoch_loss)
    pbar.set_description(f'{epoch+1} Loss : {epoch_loss:.2f}  |  Accuracy : {(corrects/total):.2f}')

    # Keep the checkpoint with the lowest validation loss and report the new
    # best value (previously the stale value, `inf` at first, was printed)
    if best_loss > valid_loss[-1]:
        best_loss = valid_loss[-1]
        print(best_loss)
        torch.save(model.state_dict(), "CA-1.pth")

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [None]:
# Evaluate the best checkpoint of the consistency-regularization run on the
# test set. map_location makes the load work even when the checkpoint was
# written on a GPU and the current runtime only has a CPU (or vice versa).
model.load_state_dict(torch.load("CA-1.pth", map_location=args.device))
model.eval()
corrects = 0
total = 0
with torch.no_grad():
    for inputs, targets in tqdm(test_loader):
        inputs = inputs.to(args.device)
        targets = targets.to(args.device)

        # Forward
        outputs = model(inputs)

        # Predict the class with the largest logit
        preds = outputs.argmax(-1)

        # Accumulate the number of correct predictions
        corrects += torch.sum(preds == targets).item()
        total += targets.size(0)

test_acc = corrects / total
print('Testing Acc: {:.4f}'.format(test_acc))

## 4. Semi-supervised learning (Consistency regularization + Sharpening)

In [None]:
# NOTE(review): presumably re-initializes the model's weights so that the
# sharpening experiment starts from scratch instead of from the checkpoint
# loaded above -- confirm against initialize_weights in models.py.
initialize_weights(model)

In [None]:
# Fresh optimizer for the second experiment (SGD with Nesterov momentum); the
# previously defined loss function is reused.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.weight_decay,
    nesterov=True,
)

In [None]:
# Semi-supervised training with consistency regularization + sharpening.
# Same procedure as the plain consistency run, except that the weak-view
# prediction is sharpened (p**2, renormalized) before being used as the target
# for the strong-view prediction.
since = time.time()

train_loss = []
valid_loss = []
best_loss = np.inf

for epoch in (pbar :=tqdm(range(args.total_epoch))):

    # ---------------- Training phase ----------------
    corrects = 0
    total = 0
    epoch_loss = 0
    model.train()
    unlabeled_iter = iter(unlabeled_loader)
    for label_inputs, label_targets in train_labeled_loader:

        label_inputs = label_inputs.to(args.device)
        label_targets = label_targets.to(args.device)

        # Restart the unlabeled loader once it is exhausted. Only StopIteration
        # is caught so that genuine errors (data loading failures, keyboard
        # interrupts, ...) are no longer silently swallowed and retried.
        try:
            (weak_inputs, strong_inputs), _ = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(unlabeled_loader)
            (weak_inputs, strong_inputs), _ = next(unlabeled_iter)
        weak_inputs = weak_inputs.to(args.device)
        strong_inputs = strong_inputs.to(args.device)

        # Reset the parameter gradients
        optimizer.zero_grad()

        # Forward all three groups in a single pass, then split the logits
        # back by the actual size of each group (no equal-size assumption).
        # `weak_unlabeled_logits` is also read by the plotting cell afterwards.
        full_logits = model(torch.cat([label_inputs, weak_inputs, strong_inputs], dim=0))
        label_logits, weak_unlabeled_logits, strong_unlabeled_logits = full_logits.split(
            [label_inputs.size(0), weak_inputs.size(0), strong_inputs.size(0)]
        )
        label_loss = criterion(label_logits, label_targets)

        # Sharpened target: square the weak-view probabilities and renormalize
        # each row to sum to one. The target is detached up front, so softmax
        # is computed once and no graph is built for it.
        weak_probs = weak_unlabeled_logits.softmax(-1).detach()
        squared_probs = weak_probs ** 2
        sharpened_target = squared_probs / squared_probs.sum(-1, keepdim=True)

        unlabeled_loss = F.mse_loss(input=strong_unlabeled_logits.softmax(-1), target=sharpened_target)
        tot_loss = args.lambda_u*unlabeled_loss + label_loss

        # Backward
        tot_loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss
        epoch_loss += tot_loss.item()

        # Predict the class with the largest logit
        preds = label_logits.argmax(dim=-1)

        # Accumulate the number of correct labeled predictions
        batch_corrects = torch.sum(preds == label_targets).item()
        corrects += batch_corrects
        total += label_targets.size(0)

        pbar.set_description(f'Loss : {tot_loss.item():.2f}  |  Accuracy: {batch_corrects/label_targets.size(0):.2f}')

    # Mean training loss per batch for this epoch. It is already an average,
    # so it must not be divided by the batch count a second time when shown.
    epoch_loss /= len(train_labeled_loader)
    train_loss.append(epoch_loss)
    pbar.set_description(f'{epoch+1} Loss : {epoch_loss:.2f}  |  Accuracy : {(corrects/total):.2f}')

    # ---------------- Validation phase ----------------
    corrects = 0
    total = 0
    epoch_loss = 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in valid_labeled_loader:
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)

            # Forward
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Accumulate the loss weighted by batch size so that a smaller
            # final batch does not bias the epoch average
            epoch_loss += loss.item() * targets.size(0)

            # Predict the class with the largest logit
            preds = outputs.argmax(dim=-1)

            # Accumulate the number of correct predictions
            batch_corrects = torch.sum(preds == targets).item()
            corrects += batch_corrects
            total += targets.size(0)

            pbar.set_description(f'Loss : {loss.item():.2f}  |  Accuracy: {batch_corrects/targets.size(0):.2f}')

    # Mean validation loss per sample for this epoch
    epoch_loss /= total
    valid_loss.append(epoch_loss)
    pbar.set_description(f'{epoch+1} Loss : {epoch_loss:.2f}  |  Accuracy : {(corrects/total):.2f}')

    # Keep the checkpoint with the lowest validation loss and report the new
    # best value (previously the stale value, `inf` at first, was printed)
    if best_loss > valid_loss[-1]:
        best_loss = valid_loss[-1]
        print(best_loss)
        torch.save(model.state_dict(), "CA-2.pth")

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [None]:
# Evaluate the best checkpoint of the consistency + sharpening run on the test
# set. map_location makes the load work even when the checkpoint was written
# on a GPU and the current runtime only has a CPU (or vice versa).
model.load_state_dict(torch.load("CA-2.pth", map_location=args.device))
model.eval()
corrects = 0
total = 0
with torch.no_grad():
    for inputs, targets in tqdm(test_loader):
        inputs = inputs.to(args.device)
        targets = targets.to(args.device)

        # Forward
        outputs = model(inputs)

        # Predict the class with the largest logit
        preds = outputs.argmax(-1)

        # Accumulate the number of correct predictions
        corrects += torch.sum(preds == targets).item()
        total += targets.size(0)

test_acc = corrects / total
print('Testing Acc: {:.4f}'.format(test_acc))

In [None]:
# Visualize the effect of sharpening on one unlabeled sample: class
# probabilities of the weak view before and after squaring + renormalizing.
# Uses `weak_unlabeled_logits` left over from the last training step.
import matplotlib.pyplot as plt

# Compute the probabilities once; the number of bars follows the model output
# instead of a hard-coded 10, so the plot keeps working if num_classes changes.
weak_probs = weak_unlabeled_logits.softmax(-1).detach()
class_positions = np.arange(weak_probs.size(-1))

_, axs = plt.subplots(2,1,figsize=(10,5))

axs[0].bar(x=class_positions, height=weak_probs[0].cpu().numpy(), alpha=0.5)
axs[0].set_title("before")
axs[0].set_ylim(0,1)

# Sharpened distribution: p**2 renormalized so each row sums to one
squared_probs = weak_probs ** 2
target = squared_probs / squared_probs.sum(-1, keepdim=True)
axs[1].bar(x=class_positions, height=target[0].cpu().numpy(), alpha=.5)
axs[1].set_title("after")
axs[1].set_ylim(0,1)
plt.show()