In [70]:
from convnext import get_convnext
from interp_utils import register_hook, remove_hooks
from imagenet_val import ValData, simple_labels
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Validation data wrapper; later cells iterate it with val.as_batches(...).
val = ValData()

# Build the network, switch it to eval mode and move it to the chosen device.
# The trailing semicolon suppresses the notebook's echo of the module repr.
model = get_convnext()
model.eval().to(device);

In [71]:
def label_smoothing_loss(logits, targets, smoothing=0.1):
    """Cross-entropy mixed with a uniform-target penalty (label smoothing).

    Computes ``(1 - smoothing) * CE(logits, targets) - smoothing * mean(log_softmax(logits))``,
    where the mean runs over every element of the log-probability tensor.

    Args:
        logits: unnormalised class scores; softmax is taken over the last dim.
        targets: class indices, passed straight to ``F.cross_entropy``.
        smoothing: weight of the uniform term; 0 gives plain cross-entropy.

    Returns:
        A scalar tensor holding the smoothed loss.
    """
    # Standard cross-entropy against the hard labels.
    hard_label_term = F.cross_entropy(logits, targets)
    # Average log-probability over all entries; its negation is the
    # cross-entropy against a uniform target distribution.
    mean_logprob = F.log_softmax(logits, dim=-1).mean()
    return (1 - smoothing) * hard_label_term - smoothing * mean_logprob

In [89]:
def gelu_to_relu_layer(model, layer):
    """Make the activation of one block behave like ReLU.

    First calls ``remove_hooks(model, quiet=True)``, then registers a hook on
    ``model.blocks[layer].act`` that discards the module's own output and
    returns ``F.relu`` of the module's first input instead.

    Args:
        model: object exposing an indexable ``blocks`` whose items have ``act``.
        layer: index into ``model.blocks`` selecting the block to modify.

    Returns:
        Whatever ``register_hook`` returns for the installed hook.
    """
    # Clear previously installed hooks so only the requested layer is altered.
    # NOTE(review): assumes remove_hooks drops every hook on the model — confirm.
    remove_hooks(model, quiet=True)

    def gelu_to_relu(module, input, output):
        # `input` is the tuple of positional inputs to the activation module;
        # the computed output is ignored and replaced by ReLU(input[0]).
        return F.relu(input[0])

    target_activation = model.blocks[layer].act
    return register_hook(target_activation, gelu_to_relu)

In [96]:
from tqdm import tqdm

BATCH_SIZE = 100
# Number of layer indices swept below.
# NOTE(review): presumably equal to len(model.blocks) — confirm before reusing
# this cell with a different model.
NUM_LAYERS = 18

# Entry k holds the validation loss measured while only layer k is hooked.
ls_losses = torch.zeros(NUM_LAYERS)
ce_losses = torch.zeros(NUM_LAYERS)

try:
    for layer_idx in range(NUM_LAYERS):
        print(f"Layer {layer_idx}")
        # gelu_to_relu_layer calls remove_hooks before registering, so the
        # hook from the previous iteration does not accumulate.
        gelu_to_relu_layer(model, layer_idx)
        val_set_pbar = tqdm(val.as_batches(BATCH_SIZE), total=len(val)//BATCH_SIZE,)
        val_set_pbar.set_description(f"Layer {layer_idx}")
        ls_batch_losses = []
        ce_batch_losses = []
        for batch in val_set_pbar:
            batch_data = batch.data.to(device)
            targets = torch.tensor(batch.class_ids).to(device)
            # Pure evaluation: no autograd graph is needed.
            with torch.no_grad():
                logits = model(batch_data)
                ls_loss = label_smoothing_loss(logits, targets)
                ce_loss = F.cross_entropy(logits, targets)
            ls_batch_losses.append(ls_loss)
            ce_batch_losses.append(ce_loss)
            # The bar shows the loss of the most recent batch, not a running mean.
            val_set_pbar.set_description(f"Layer {layer_idx} LS: {ls_loss:.3f} CE: {ce_loss:.3f}")
        # Unweighted mean of per-batch means; this equals the per-example mean
        # only when every batch has the same number of examples.
        ls_losses[layer_idx] = torch.stack(ls_batch_losses).mean()
        ce_losses[layer_idx] = torch.stack(ce_batch_losses).mean()
finally:
    # Do not leave the last (or an interrupted) layer's hook installed:
    # any later evaluation would otherwise silently run on an ablated model.
    remove_hooks(model, quiet=True)


Layer 0


Layer 0 LS: 2.380 CE: 1.639: 100%|██████████| 500/500 [03:27<00:00,  2.41it/s]


Layer 1


Layer 1 LS: 2.508 CE: 1.781: 100%|██████████| 500/500 [03:37<00:00,  2.30it/s]


Layer 2


Layer 2 LS: 2.353 CE: 1.604: 100%|██████████| 500/500 [03:52<00:00,  2.15it/s]


Layer 3


Layer 3 LS: 2.307 CE: 1.552: 100%|██████████| 500/500 [03:51<00:00,  2.16it/s]


Layer 4


Layer 4 LS: 2.250 CE: 1.482: 100%|██████████| 500/500 [03:57<00:00,  2.11it/s]


Layer 5


Layer 5 LS: 2.157 CE: 1.393: 100%|██████████| 500/500 [03:53<00:00,  2.14it/s]


Layer 6


Layer 6 LS: 2.144 CE: 1.382: 100%|██████████| 500/500 [04:07<00:00,  2.02it/s]


Layer 7


Layer 7 LS: 2.117 CE: 1.344: 100%|██████████| 500/500 [04:08<00:00,  2.01it/s]


Layer 8


Layer 8 LS: 2.141 CE: 1.366: 100%|██████████| 500/500 [04:30<00:00,  1.85it/s]


Layer 9


Layer 9 LS: 2.152 CE: 1.392: 100%|██████████| 500/500 [04:03<00:00,  2.05it/s]


Layer 10


Layer 10 LS: 2.118 CE: 1.353: 100%|██████████| 500/500 [03:24<00:00,  2.44it/s]


Layer 11


Layer 11 LS: 2.118 CE: 1.334: 100%|██████████| 500/500 [03:22<00:00,  2.46it/s]


Layer 12


Layer 12 LS: 2.093 CE: 1.309: 100%|██████████| 500/500 [04:08<00:00,  2.01it/s]


Layer 13


Layer 13 LS: 2.100 CE: 1.317: 100%|██████████| 500/500 [03:59<00:00,  2.09it/s]


Layer 14


Layer 14 LS: 2.108 CE: 1.323: 100%|██████████| 500/500 [03:21<00:00,  2.48it/s]


Layer 15


Layer 15 LS: 2.095 CE: 1.312: 100%|██████████| 500/500 [03:22<00:00,  2.47it/s]


Layer 16


Layer 16 LS: 2.095 CE: 1.311: 100%|██████████| 500/500 [03:21<00:00,  2.48it/s]


Layer 17


Layer 17 LS: 2.096 CE: 1.313: 100%|██████████| 500/500 [03:22<00:00,  2.47it/s]


In [98]:
# Baseline: evaluate the model with no activation hooks installed.
# NOTE(review): assumes remove_hooks(model) drops every hook registered on the
# model — confirm against interp_utils.
remove_hooks(model)

BATCH_SIZE = 100


val_set_pbar = tqdm(val.as_batches(BATCH_SIZE), total=len(val)//BATCH_SIZE,)
# Label the bar for this run; the previous label reused `layer_idx`, which only
# exists if the per-layer sweep cell was executed earlier in the same kernel.
val_set_pbar.set_description("Without ablation")
ls_batch_losses = []
ce_batch_losses = []
for batch in val_set_pbar:
    batch_data = batch.data.to(device)
    targets = torch.tensor(batch.class_ids).to(device)
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch_data)
        ls_loss = label_smoothing_loss(logits, targets)
        ce_loss = F.cross_entropy(logits, targets)
    ls_batch_losses.append(ls_loss)
    ce_batch_losses.append(ce_loss)
    val_set_pbar.set_description(f"Without ablation | LS: {ls_loss:.3f} CE: {ce_loss:.3f}")

# From here on ls_loss / ce_loss hold the mean over all batches rather than the
# last batch's value. It is an unweighted mean of per-batch means, which equals
# the per-example mean only when every batch has the same size.
ls_loss = torch.stack(ls_batch_losses).mean()
ce_loss = torch.stack(ce_batch_losses).mean()


Without ablation. LS: 2.096 CE: 1.312: 100%|██████████| 500/500 [03:23<00:00,  2.45it/s]


In [6]:
import plotly.express as px
import torch

# Load previously saved results instead of re-running the evaluation.
# NOTE(review): if the file stores NumPy arrays, torch versions that default
# torch.load to weights_only=True may refuse to load it — confirm.
data = torch.load('gelu_to_relu_results.pt')
ce_losses, ce_loss = data['ce_losses_by_relud_layer'], data['normal_ce_loss']
ls_losses, ls_loss = data['ls_losses_by_relud_layer'], data['normal_ls_loss']

# One scatter plot per loss type; y is the per-layer loss minus the baseline
# loss, and the x axis is the position in the per-layer array.
for loss_delta, plot_title in (
    (ce_losses - ce_loss, "Cross Entropy Loss change"),
    (ls_losses - ls_loss, "Label Smoothing Loss change"),
):
    fig = px.scatter(y=loss_delta, title=plot_title)
    fig.update_layout(xaxis_title="Layer", yaxis_title="Loss change after swapping gelus with relus")
    fig.show()

In [114]:
# Bundle the per-layer and baseline losses as CPU NumPy values.
data = {
    'ls_losses_by_relud_layer': ls_losses.cpu().numpy(),
    'ce_losses_by_relud_layer': ce_losses.cpu().numpy(),
    'normal_ls_loss': ls_loss.cpu().numpy(),
    'normal_ce_loss': ce_loss.cpu().numpy(),
}

# Write to the same path the plotting cell reads ('gelu_to_relu_results.pt');
# the previous name ('gelu_to_relu_ablation_results.pt') was never loaded
# anywhere in this notebook, so freshly saved results were not the ones plotted.
torch.save(data, 'gelu_to_relu_results.pt')