# Freeze

Objective: 입력이 Enc -> Dec -> Enc 순서로 통과할 때, 두 번째 enc 만 freeze 해서 두 번째 enc 호출로 인한 gradient 가 enc 파라미터에 쌓이지 않게 하려면? (dec 와 첫 번째 enc 로는 gradient 가 그대로 전달되어야 함)  
크게 3가지 방법이 있음:

1. Cloning: run the second pass through a frozen deep copy of enc
2. Manual grad: compute grads with `torch.autograd.grad` instead of `backward()`
3. Inference-only freezing: temporarily set `requires_grad=False` on enc's parameters during the second forward pass

In [1]:
import os
import re
import numpy as np
import h5py as h5
import copy
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as tv_utils
import numpy as np
from PIL import Image
import time

In [2]:
class Encoder(nn.Module):
    """Single-linear-layer encoder used to demonstrate gradient freezing.

    Args:
        in_features: Size of each input sample. Defaults to 3.
        out_features: Size of each encoded feature vector. Defaults to 10.
    """

    def __init__(self, in_features=3, out_features=10):
        super().__init__()
        # Kept inside ``nn.Sequential`` so parameter names stay ``net.0.*``.
        self.net = nn.Sequential(
            nn.Linear(in_features, out_features),
        )

    def forward(self, x):
        """Return ``x`` passed through the encoder's linear layer."""
        return self.net(x)
    
class Decoder(nn.Module):
    """Single-linear-layer decoder used to demonstrate gradient freezing.

    Args:
        in_features: Size of each encoded feature vector. Defaults to 10.
        out_features: Size of each decoded sample. Defaults to 3.
    """

    def __init__(self, in_features=10, out_features=3):
        super().__init__()
        # Kept inside ``nn.Sequential`` so parameter names stay ``net.0.*``.
        self.net = nn.Sequential(
            nn.Linear(in_features, out_features)
        )

    def forward(self, x):
        """Return ``x`` passed through the decoder's linear layer."""
        return self.net(x)

In [3]:
def zero_grad():
    """Clear the gradients of the global ``enc`` and ``dec`` modules."""
    for module in (enc, dec):
        module.zero_grad()

In [4]:
def compute_loss(r):
    """Toy loss for the demo: the result of ``r.sum()``."""
    total = r.sum()
    return total

In [5]:
def grad_stats(model):
    """Return the mean and standard deviation of all gradients of ``model``.

    The gradients of every parameter are flattened and concatenated into one
    vector before the statistics are taken. Parameters whose ``.grad`` is
    ``None`` (for example frozen or unused parameters, or parameters that have
    not been back-propagated yet) are skipped instead of raising an
    ``AttributeError``.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        A ``(mean, std)`` tuple of Python floats.

    Raises:
        ValueError: If no parameter of ``model`` currently holds a gradient.
    """
    grads = [p.grad.flatten() for p in model.parameters() if p.grad is not None]
    if not grads:
        raise ValueError(
            "no parameter of `model` has a gradient; run a backward pass first"
        )

    g = torch.cat(grads)
    return g.mean().item(), g.std().item()

In [6]:
def baseline(x):
    """Reference run: the second encoder pass is NOT frozen.

    Runs ``x`` through the global ``enc`` -> ``dec`` -> ``enc``, back-propagates
    the loss, and returns ``(grad_stats(enc), grad_stats(dec))``.
    """
    encoded_again = enc(dec(enc(x)))
    compute_loss(encoded_again).backward()

    enc_stats = grad_stats(enc)
    dec_stats = grad_stats(dec)
    return enc_stats, dec_stats

In [7]:
def cloning(x):
    """Cloning method: use the global ``freeze_enc`` for the second pass.

    Runs ``x`` through ``enc`` -> ``dec`` -> ``freeze_enc``, back-propagates the
    loss, and returns ``(grad_stats(enc), grad_stats(dec))``.
    """
    reconstructed = dec(enc(x))
    # Second pass goes through the separate copy, not through ``enc`` itself.
    second_pass = freeze_enc(reconstructed)
    compute_loss(second_pass).backward()

    enc_stats = grad_stats(enc)
    dec_stats = grad_stats(dec)
    return enc_stats, dec_stats

In [8]:
def freezing(x):
    """Inference-only freezing: disable ``requires_grad`` for the second pass.

    Runs ``x`` through the global ``enc`` -> ``dec`` -> ``enc``. The encoder's
    parameters have ``requires_grad`` switched off only while the second
    encoder pass is executed, then the loss is back-propagated.

    The previous ``requires_grad`` flag of every parameter is remembered and
    restored in a ``finally`` clause, so the encoder is never left frozen when
    the forward pass raises, and parameters that were already frozen before
    the call stay frozen afterwards.

    Returns:
        ``(grad_stats(enc), grad_stats(dec))``.
    """
    x = enc(x)
    x = dec(x)

    params = list(enc.parameters())
    previous_flags = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)
    try:
        x = enc(x)
    finally:
        # Restore each parameter's own flag rather than forcing True.
        for p, flag in zip(params, previous_flags):
            p.requires_grad_(flag)

    loss = compute_loss(x)
    loss.backward()

    return grad_stats(enc), grad_stats(dec)

In [9]:
# def two_step_backward(params1, loss, params2, feat, retain_graph=False):
#     """
#     module := params1 + params2
#     loss.backward(params1)  # exclude params2
#     feat.backward(params2, feat.grad)  # feat has grads from step 1
#     """
#     params1 = list(params1)
#     params2 = list(params2)
    
#     grads1 = torch.autograd.grad(loss, params1 + [feat], retain_graph=retain_graph)
#     grads1, feat_grad = grads1[:-1], grads1[-1]
#     grads2 = torch.autograd.grad(feat, params2, feat_grad, retain_graph=retain_graph)
    
#     with torch.no_grad():
#         for p, g in zip(params1 + params2, grads1 + grads2):
#             p.grad = g.detach()

def two_step_backward(params1, feat, params2, loss, **kwargs):
    """Compute gradients in two steps with ``torch.autograd.grad`` and store them.

    Step 1 differentiates ``loss`` with respect to ``feat`` and ``params2``.
    Step 2 differentiates ``feat`` with respect to ``params1``, using the
    gradient of ``feat`` from step 1 as ``grad_outputs``. Parameters in
    ``params1`` therefore only receive the gradient that reaches them through
    ``feat``.

    The results are written to ``p.grad`` (overwriting, not accumulating, any
    existing value) as detached tensors.

    Args:
        params1: Encoder parameters (iterable).
        feat: Encoded features produced from ``params1``.
        params2: Decoder parameters (iterable).
        loss: Loss to differentiate.
        **kwargs: Forwarded unchanged to both ``torch.autograd.grad`` calls,
            e.g. ``retain_graph``, ``create_graph`` or ``allow_unused``.
            Because the stored gradients are detached, a graph built with
            ``create_graph=True`` is not reachable from ``p.grad``.
    """
    params1 = list(params1)
    params2 = list(params2)

    # Step 1: d(loss)/d(feat) and d(loss)/d(params2).
    grads2 = torch.autograd.grad(loss, [feat] + params2, **kwargs)
    feat_grad, grads2 = grads2[0], grads2[1:]

    # Step 2: push the feature gradient back into params1.
    if feat_grad is None:
        # Only possible with allow_unused=True: loss does not depend on feat,
        # so nothing flows back to params1.
        grads1 = (None,) * len(params1)
    else:
        grads1 = torch.autograd.grad(feat, params1, feat_grad, **kwargs)

    with torch.no_grad():
        for p, g in zip(params1 + params2, tuple(grads1) + tuple(grads2)):
            # With allow_unused=True, unused inputs come back as None.
            p.grad = None if g is None else g.detach()

def grading(x):
    """Manual-grad method: fill ``.grad`` via ``two_step_backward``.

    Runs ``x`` through the global ``enc`` -> ``dec`` -> ``enc`` and, instead of
    calling ``backward()``, hands the first encoder output and the loss to
    ``two_step_backward``. Returns ``(grad_stats(enc), grad_stats(dec))``.
    """
    encoded = enc(x)
    loss = compute_loss(enc(dec(encoded)))

    two_step_backward(enc.parameters(), encoded, dec.parameters(), loss)

    enc_stats = grad_stats(enc)
    dec_stats = grad_stats(dec)
    return enc_stats, dec_stats

In [10]:
# Models under test; the encoder is constructed before the decoder.
enc = Encoder()
dec = Decoder()

# Deep copy of enc with requires_grad disabled on every parameter,
# used as the second encoder by the cloning method.
freeze_enc = copy.deepcopy(enc).requires_grad_(False)

In [11]:
# Batch size of the demo input.
B = 4
# Random batch of B samples with 3 values each.
x = torch.rand((B, 3))

In [12]:
# Reference run: clear existing gradients, then back-propagate with nothing frozen.
zero_grad()
baseline(x)

((1.2913068532943726, 1.9068480730056763),
 (0.006920290645211935, 1.1694526672363281))

In [13]:
# Cloning method: clear existing gradients, then use the frozen copy for the second pass.
zero_grad()
cloning(x)

((0.15245112776756287, 0.48006945848464966),
 (0.006920290645211935, 1.1694526672363281))

In [14]:
# Freezing method: clear existing gradients, then toggle requires_grad around the second pass.
zero_grad()
freezing(x)

((0.15245112776756287, 0.48006945848464966),
 (0.006920290645211935, 1.1694526672363281))

In [15]:
# Manual-grad method: clear existing gradients, then fill .grad via torch.autograd.grad.
zero_grad()
grading(x)

((0.15245112776756287, 0.48006945848464966),
 (0.006920290645211935, 1.1694526672363281))