# Weight datasets

In [1]:
# Show the kernel's working directory; the relative `path` defined below is resolved against it.
%pwd

'/net/tscratch/people/plgkingak'

In [2]:
# Root folder of the weight datasets (relative to the working directory) and
# the device string handed to torch when tensors are loaded.
path, device = "weights2weights/weights_datasets", "cuda"

In [3]:
import os

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, concatenate_datasets

In [4]:
import matplotlib.pyplot as plt

In [5]:
# Load the table of weight dimensions: a mapping from parameter name to a pair
# of shapes (see the preview of `dim_df` below).
weight_dimensions = torch.load(f"{path}/weight_dimensions.pt")

In [6]:
# Transpose so that every parameter name becomes a row and the two shapes
# become columns 0 and 1.
dim_df = pd.DataFrame(weight_dimensions).T

In [None]:
# Preview the table: one row per parameter name, columns 0 and 1 hold the two stored shapes.
dim_df

Unnamed: 0,0,1
base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight,"(320,)","(1, 320)"
base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.weight,"(320,)","(320, 1)"
base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.lora_A.weight,"(320,)","(1, 320)"
base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.lora_B.weight,"(320,)","(320, 1)"
base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.lora_A.weight,"(320,)","(1, 320)"
...,...,...
base_model.model.up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.lora_B.weight,"(320,)","(320, 1)"
base_model.model.up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.lora_A.weight,"(320,)","(1, 320)"
base_model.model.up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.lora_B.weight,"(320,)","(320, 1)"
base_model.model.up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.lora_A.weight,"(768,)","(1, 768)"


## Mid block

In [8]:
# Flag rows whose parameter name mentions "mid_block": 1 = selected, 0 = ignored.
dim_df["use"] = 0
mid_rows = dim_df.index.str.contains("mid_block")
dim_df.loc[mid_rows, "use"] = 1

In [9]:
# Number of rows flagged for use.
dim_df.use.sum()

8

In [10]:
# Element count of each parameter, taken from the first entry of every
# (shape, shape) pair in `weight_dimensions`.
dim_df["lengths"] = [shape_pair[0].numel() for shape_pair in weight_dimensions.values()]

In [11]:
# Each parameter occupies the half-open range [start, end) when the lengths
# are laid out one after another; these ranges are used below to pick columns.
end_offsets = dim_df["lengths"].cumsum()
dim_df["start"] = end_offsets - dim_df["lengths"]
dim_df["end"] = dim_df["start"] + dim_df["lengths"]

In [12]:
# Show only the rows selected for the mid-block subset, with their offsets.
dim_df[dim_df.use == 1]

Unnamed: 0,0,1,use,lengths,start,end
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight,"(1280,)","(1, 1280)",1,1280,35968,37248
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.weight,"(1280,)","(1280, 1)",1,1280,37248,38528
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn1.to_v.lora_A.weight,"(1280,)","(1, 1280)",1,1280,38528,39808
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn1.to_v.lora_B.weight,"(1280,)","(1280, 1)",1,1280,39808,41088
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn2.to_q.lora_A.weight,"(1280,)","(1, 1280)",1,1280,41088,42368
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn2.to_q.lora_B.weight,"(1280,)","(1280, 1)",1,1280,42368,43648
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn2.to_v.lora_A.weight,"(768,)","(1, 768)",1,768,43648,44416
base_model.model.mid_block.attentions.0.transformer_blocks.0.attn2.to_v.lora_B.weight,"(1280,)","(1280, 1)",1,1280,44416,45696


In [13]:
# Expand every selected [start, end) range into explicit positions and join
# them into a single index array.
selected = dim_df[dim_df.use == 1]
indices = np.concatenate(
    [np.arange(lo, hi) for lo, hi in zip(selected["start"], selected["end"])]
)
len(indices)


9728

### Save

In [4]:
# Load the full weight matrix; the second positional argument of torch.load is
# map_location, so the tensor is placed on `device`.
weights = torch.load(f"{path}/identities/all_weights.pt", torch.device(device))

In [None]:
# Keep only the selected columns and store them as the mid-block subset.
torch.save(weights[:,indices], f'{path}/mid_block.pt')

In [4]:
# Read the saved mid-block subset back, placing it on `device`.
dataset = torch.load(f'{path}/mid_block.pt', torch.device(device))

In [None]:
# Wrap the tensor in a Dataset with a single "data" column and write it to disk.
Dataset.from_dict({"data":dataset}).save_to_disk(f'{path}/mid_block')

In [8]:
# Build the same single-column Dataset in memory for inspection.
d = Dataset.from_dict({"data":dataset})

In [None]:
# Display the Dataset summary.
d

In [None]:
# Length of the "data" entry of the first example.
len(d[0]["data"])

## Evaluation samples

In [None]:
# Load the full weight matrix onto `device` (same file as in the mid-block section).
weights = torch.load(f"{path}/identities/all_weights.pt", torch.device(device))

In [8]:
# Draw 10 distinct rows of `weights` to serve as evaluation samples.
# The permutation comes from an explicitly seeded generator so that re-running
# this cell writes the same samples again; 42 matches the random_state used for
# the pandas sampling elsewhere in this notebook.
sample_rng = torch.Generator().manual_seed(42)
random_indices = torch.randperm(weights.size(0), generator=sample_rng)[:10]
sampled_rows = weights[random_indices]

torch.save(sampled_rows, f'{path}/evaluation_samples.pt')

## Blondes (and brunettes)

In [26]:
import torch
# Identity attribute table (columns such as "Blond_Hair" / "Black_Hair" are used below).
# The location is given relative to the working directory, like `path`, instead
# of an absolute path inside one user's scratch space.
df = torch.load('weights2weights/files/identity_df.pt')

In [27]:
# Move the existing index into a regular column so rows are labelled 0..n-1;
# those labels are used below to pick rows of `weights`.
df = df.reset_index()

In [28]:
# Split the identities by hair attribute (rows where the column equals 1).
# NOTE(review): "brunettes" is selected through the "Black_Hair" column —
# confirm that this is the intended attribute.
blondes = df[df["Blond_Hair"].eq(1)]
brunettes = df[df["Black_Hair"].eq(1)]

In [None]:
# Preview the seeded 2000-row draws; only the last expression (the brunettes
# sample) is rendered as the cell output.
blondes.sample(n=2000, random_state=42)
brunettes.sample(n=2000, random_state=42)

In [23]:
# Load the full weight matrix onto `device` (same file as in the earlier sections).
weights = torch.load(f"{path}/identities/all_weights.pt", torch.device(device))

In [None]:
# Reproducible draw of 2000 blond identities; their row labels select the
# matching rows of `weights`.
# Assumes the rows of `df` are in the same order as the rows of `weights` — TODO confirm.
blonde_rows = blondes.sample(n=2000, random_state=42).index
blonde_weights = weights[blonde_rows]
blonde_weights.shape

In [34]:
# Persist the blond subset; .clone() hands torch.save an independent copy.
torch.save(blonde_weights.clone(), f"{path}/blondes.pt")

In [None]:
# Reproducible draw of 2000 identities from `brunettes`; their row labels
# select the matching rows of `weights`.
# Assumes the rows of `df` are in the same order as the rows of `weights` — TODO confirm.
brunette_rows = brunettes.sample(n=2000, random_state=42).index
brunette_weights = weights[brunette_rows]
brunette_weights.shape

In [36]:
# Persist the brunette subset; .clone() hands torch.save an independent copy.
torch.save(brunette_weights.clone(), f"{path}/brunettes.pt")

## Split

In [8]:
import torch
# Identity table whose index is used below to derive one file name per row.
# The location is given relative to the working directory, like `path`, instead
# of an absolute path inside one user's scratch space.
df = torch.load('weights2weights/files/identity_df.pt')

In [9]:
# Load the full weight matrix onto `device` (same file as in the earlier sections).
weights = torch.load(f"{path}/identities/all_weights.pt", torch.device(device))

In [10]:
# One output name per row: the part of each index label before its first "."
# (index labels are presumably file names with an extension — TODO confirm).
filenames = df.index.str.split(".").str.get(0).tolist()

In [11]:
# Sub-folder of `path` that receives the per-row files; the rescaling section
# below overrides this with "rescaled".
name = "single"

In [12]:
# Shape of the full weight matrix.
weights.shape

torch.Size([64974, 99648])

### Check values

Check whether any column is degenerate: entirely zero, or dominated by a single repeated value.

In [None]:
# For every column of `weights` collect
#   zcnts[i] - how many entries are exactly zero
#   ucnts[i] - how many times the most frequent value occurs
# The tallies are computed with tensor operations on the tensor's own device.
# A per-element Python loop calling .item() would need rows * columns separate
# scalar reads, which is impractical for a matrix of this size.
zero = (weights == 0)
zcnts = zero.sum(dim=0).tolist()

ucnts = []
for i in range(weights.shape[1]):
    # return_counts=True yields, next to the distinct values, how often each occurs.
    _, counts = torch.unique(weights[:, i], return_counts=True)
    ucnts.append(int(counts.max()))

In [32]:
# Total number of exactly-zero entries across all columns.
sum(zcnts)

0

In [None]:
# Histogram of, per column, the occurrence count of its most frequent value.
plt.hist(ucnts)

### If rescaling

In [9]:
# Global extremes over every entry of `weights`. They are the inputs of an
# optional min-max rescaling of the form (weights - min_val) / (max_val - min_val).
min_val, max_val = weights.min(), weights.max()

In [None]:
# Display the global minimum and maximum.
min_val, max_val

In [None]:
# Global mean and standard deviation over every entry of `weights`,
# used by the standardisation step below.
mean, std = weights.mean(), weights.std()
mean, std

In [None]:
# Standardise `weights` in place (subtract the global mean, divide by the
# global standard deviation); the tensor itself is overwritten, so re-running
# this cell applies the transformation a second time.
# Alternative kept for reference: min-max scaling to [-1, 1], i.e.
# subtract min_val, divide by (max_val - min_val), multiply by 2, subtract 1.
weights -= mean
weights /= std
weights

In [12]:
# Output sub-folder for the standardised rows; use "single" when the
# rescaling step above was skipped.
name = "rescaled" # "single"

### Save

In [13]:
# Write every row of `weights` to its own file, {path}/{name}/<filename>.pt.
out_dir = f"{path}/{name}"
# torch.save does not create missing folders, so make sure the target exists.
os.makedirs(out_dir, exist_ok=True)

# Rows and names are paired by position; refuse to write anything if the two
# do not line up instead of failing (or mislabelling) part-way through.
if len(filenames) != len(weights):
    raise ValueError(
        f"{len(filenames)} filenames for {len(weights)} weight rows"
    )

for stem, row in zip(filenames, weights):
    # A row is a view into the full matrix; clone() so that only this row's
    # data is serialised rather than the storage it shares with `weights`.
    torch.save(row.clone(), f"{out_dir}/{stem}.pt")

In [None]:
import os

# Collect the names of all entries in the per-row output folder, in sorted order.
data_dir = f"{path}/{name}/"
file_list = os.listdir(data_dir)
file_list.sort()
len(file_list)

In [None]:
def data_generator():
    """Yield one record per entry of `file_list`, loading tensors one at a time.

    Reads the module-level `file_list` and `data_dir`.

    Yields:
        dict with keys "filename" (the file's name inside `data_dir`) and
        "data" (the stored tensor converted to a NumPy array).
    """
    for filename in file_list:
        file_path = os.path.join(data_dir, filename)
        # Restore straight to host memory: the tensor is converted to NumPy
        # right away, so placing it on another device first is wasted work
        # (and would require that device to be present).
        tensor = torch.load(file_path, map_location="cpu")
        yield {"filename": filename, "data": tensor.numpy()}

dataset = Dataset.from_generator(data_generator, cache_dir=f"{path}/.cache")
# Name the saved dataset after the folder it was built from, so that data from
# "single" is not stored under a "rescaled" label (name == "rescaled" still
# gives ".../full_rescaled").
dataset.save_to_disk(f"{path}/full_{name}")

In [None]:
# Inspect the table object backing the Dataset.
dataset.data

In [None]:
# Spot check: largest value in the "data" entry of the third record.
max(dataset[2]["data"])