# Visualize lambda weights

Visualize the calculated lambda weights ot understand whether the weight network always favors the original rotation or not.
Reproduce and extend the image in the Appendix D.5, p23.

In [2]:
%cd DL2-2024/

/teamspace/studios/this_studio/DL2-2024


In [3]:
%load_ext autoreload
%autoreload 2

import wandb
run = wandb.init()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33madamdivak[0m ([33mCV2-project[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [63]:
wandb_models = {
    "lambda_equitune": {
        "model_link": "dl2-2024/dl-2024/Weighting_model:v0",
        "method": "equitune"
    },
    "lambda_equiattention": {
            "model_link": "dl2-2024/dl-2024/Weighting_model:v8",
            "method": "attention"
        }
}

from EquiCLIP.visualize_lambda import main as visualize_lambda_main

for model_name, details in wandb_models.items():
    artifact = run.use_artifact(details["model_link"], type='model')
    artifact_dir = artifact.download()
    model_files[model_name] = artifact.file()
    visualize_lambda_main([
        "--dataset_name", "CIFAR100",
        "--method", details["method"],
        "--group_name", "rot90",
        "--data_transformations", "rot90",
        "--model_file", model_files[model_name],
        "--model_display_name", model_name,
        "--output_filename_suffix", model_name
    ])


[34m[1mwandb[0m:   1 of 1 files downloaded.  
Global seed set to 0


Namespace(seed=0, device='cuda:0', img_num=0, num_prefinetunes=10, data_transformations='rot90', group_name='rot90', method='equitune', model_name='RN50', dataset_name='CIFAR100', verbose=True, softmax=False, use_underscore=False, load=False, full_finetune=False, visualize_features=False, model_file='./artifacts/Weighting_model:v0/CIFAR100_RN50_aug_rot90_eq_rot90_steps_10.pt', output_filename_suffix='lambda_equitune', model_display_name='lambda_equitune')
Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
loaded zeroshot weights!
Loading model from ./artifacts/Weighting_model:v0/CIFAR100_RN50_aug_rot90_eq_rot90_steps_10.pt


100%|██████████| 834/834 [00:45<00:00, 18.15it/s]


          0   90       180  270 model_name model_display_name dataset_name  \
0  1.000000  0.0  0.000000  0.0       RN50    lambda_equitune     CIFAR100   
1  0.999023  0.0  0.001108  0.0       RN50    lambda_equitune     CIFAR100   
2  1.000000  0.0  0.000000  0.0       RN50    lambda_equitune     CIFAR100   
3  1.000000  0.0  0.000000  0.0       RN50    lambda_equitune     CIFAR100   
4  1.000000  0.0  0.000013  0.0       RN50    lambda_equitune     CIFAR100   

  group_name data_transformations  full_finetune    method  
0      rot90                rot90          False  equitune  
1      rot90                rot90          False  equitune  
2      rot90                rot90          False  equitune  
3      rot90                rot90          False  equitune  
4      rot90                rot90          False  equitune  


In [68]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from pathlib import Path
pd.options.plotting.backend = "plotly"

def plot_all_weights(df, output_dir):
    dataset_name = df["dataset_name"].unique()
    assert len(dataset_name) == 1
    dataset_name = dataset_name[0]

    # Normalize weights
    # The raw values don't matter, as in the end the lambda * feature values are divided by their sum, 
    # so essentially lambda values are normalized
    group_columns = ["0", "90", "180", "270"]
    other_columns = set(df.columns) - set(group_columns)
    df[group_columns] = df[group_columns].div(df[group_columns].sum(axis=1), axis=0)

    df_statistics = df.groupby("model_display_name")[group_columns].agg(["mean", "std"])
    df_statistics = df_statistics.stack(level=0).reset_index().rename({"level_1": "group_transformation"}, axis=1)
    fig = px.bar(
        df_statistics, 
        x="group_transformation",
        y="mean", 
        error_y="std", 
        facet_col="model_display_name",
        title=f"Normalized lambda weight values for each input of {dataset_name}",
        labels={
            "model_display_name": "Model", 
            "value": "Lambda weights mean±std", 
            "group_transformation": "Group transformation (rotation, deg)"},
    )
    fig.write_image(f"{output_dir}/lamba_weight_means.png")
    display(fig)

    fig = df[group_columns + ["model_display_name"]].plot(
        kind='box', 
        title=f"Normalized lambda weight values for each input of {dataset_name}",
        labels={
            "model_display_name": "Model", 
            "value": "Lambda weights", 
            "variable": "Group transformation (rotation, deg)"},
        facet_col="model_display_name"
    )
    fig.write_image(f"{output_dir}/lamba_weight_box.png")
    display(fig)

    fig = df[group_columns + ["model_display_name"]].plot(
        kind='histogram', 
        #title=f"Normalized lambda weight values for each input of {dataset_name}",
        labels={
            "model_display_name": "Model", 
            "value": "Lambda weights", 
            "variable": "Group"},
        facet_col="model_display_name",
        facet_row="variable"
    )
    fig.write_image(f"{output_dir}/lamba_weight_histogram.png")
    display(fig)

    # FIXME update this per model name
    df_nonstandard_rotation_has_highest_weight = df[df["0"] < df[["90", "180", "270"]].max(axis=1)]
    ratio_nonstandard_rotation_has_highest_weight = df_nonstandard_rotation_has_highest_weight.shape[0] / df.shape[0]
    print(f"For {ratio_nonstandard_rotation_has_highest_weight * 100 :.2f}% of samples the highest lambda weight is not for the original rotation")

output_dir = Path("results/lambda_weights")
all_dfs = []
for df_path in output_dir.glob("*.csv"):
    df = pd.read_csv(df_path, index_col=0)
    all_dfs.append(df)
df = pd.concat(all_dfs)

df_statistics = plot_all_weights(df, output_dir)
df_statistics
#df





For 29.34% of samples the highest lambda weight is not for the original rotation


# 