In [2]:
import os, gc, sys, time, random, math

import torch
import torch.nn as nn
import torch.nn.functional as F

# from typing import Optional, List

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from diffusers import StableDiffusionPipeline
from tqdm import tqdm

from PIL import Image

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


In [2]:
class ImageDataset(Dataset):
    """Image/caption dataset driven by a metadata DataFrame.

    Each row of ``df`` supplies a ``file_name`` (relative to ``root_dir``) and
    a ``text`` caption. ``__getitem__`` returns a dict with:

    * ``"instance_images"`` -- the image converted to RGB and passed through
      ``transforms.ToTensor()``.
    * ``"instance_prompt_ids"`` -- the *raw caption string*. Despite the key
      name, no tokenization happens here; callers must tokenize it themselves.

    Args:
        root_dir: Directory that the ``file_name`` entries are relative to.
        df: DataFrame with ``file_name`` and ``text`` columns.
        size: Currently unused (kept for signature compatibility; no resizing
            is applied).
        center_crop: Currently unused (kept for signature compatibility; no
            cropping is applied).
    """

    def __init__(self, root_dir, df, size=224, center_crop=True):
        self.root_dir = root_dir
        self.files = df['file_name'].tolist()
        self.findings = df['text'].tolist()
        # Only ToTensor is applied: no resize, crop or normalization happens
        # here. NOTE(review): confirm the training loop handles those, since
        # `size` / `center_crop` are ignored.
        self.image_transforms = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.files[idx])
        # Context manager guarantees the file handle is released. convert()
        # returns a new, fully loaded image, so it remains valid afterwards.
        with Image.open(path) as img:
            rgb = img.convert("RGB")

        return {
            "instance_images": self.image_transforms(rgb),
            "instance_prompt_ids": self.findings[idx],
        }

In [3]:
# --- Paths -------------------------------------------------------------------
# Project root; every other path below is derived from it.
Main_Path = '/home/mcrespo/migros_deepL'
# Dataset folder: images plus metadata.csv, which must provide the
# `file_name` / `text` columns read by ImageDataset.
Data_storage = os.path.join(Main_Path, 'sample_flair')
save_result_path = os.path.join(Main_Path, 'selora_outputs')
reports_path = os.path.join(Data_storage, 'metadata.csv')
# Folder name for saved results (presumably under save_result_path -- confirm).
folder_name = 'loras'

SEED = 42
BATCH_SIZE = 2

# --- Splits ------------------------------------------------------------------
# 80% train; the remaining 20% is split 80/20 into valid/test, i.e. 16% valid
# and 4% test of the full metadata.
# NOTE(review): if an even valid/test split was intended, use test_size=0.5.
metadata = pd.read_csv(reports_path)
train_df, temp_df = train_test_split(metadata, test_size=0.2, random_state=SEED)
valid_df, test_df = train_test_split(temp_df, test_size=0.2, random_state=SEED)

train_ds = ImageDataset(root_dir=Data_storage, df=train_df)

# Seeded generator so the shuffle order is reproducible across runs.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(SEED),
)
# Steps per epoch = ceil(len(train_df) / BATCH_SIZE) (drop_last defaults to
# False, so the final partial batch counts). For 100 epochs the total number
# of steps is 100 * ceil(len(train_df) / BATCH_SIZE).

In [4]:
from safetensors.torch import load_file

# Directory holding the trained SeLoRA checkpoints (one sub-folder per model).
TRAINED_MODEL_DIR = '/work/scratch/mcrespo/output/test_12_26_alldataset/12-29_17h45m15s/selora_outputs/loras/trained_model'
unet_ckpt_path = os.path.join(TRAINED_MODEL_DIR, 'final_Unet', 'diffusion_pytorch_model.safetensors')
text_ckpt_path = os.path.join(TRAINED_MODEL_DIR, 'final_Text', 'model.safetensors')

# Paths and loaded state dicts live under separate names; model_text and
# model_unet hold the loaded {tensor name -> tensor} dicts.
model_text = load_file(text_ckpt_path)
model_unet = load_file(unet_ckpt_path)

In [44]:
import torch
import safetensors.torch

def merge_weights(tensor_dict, scale=1.0):
    """Fold LoRA adapters into their base weights: ``W' = W + scale * (B @ A)``.

    For every key ``<prefix>.weight`` that has sibling ``<prefix>.lora_A``
    (shape ``(r, in_features)``) and ``<prefix>.lora_B`` (shape
    ``(out_features, r)``) entries, the low-rank update ``B @ A`` is added to
    the base weight.

    All other tensors are copied through unchanged. This includes biases,
    norms, embeddings, and weights without adapters.

    Every key whose name contains ``lora_A`` or ``lora_B`` is dropped from the
    result. This includes auxiliary entries such as ``lora_A_temp`` and
    ``lora_B_temp``.

    If ``B @ A`` is 2-D but the base weight is not (e.g. a 1x1 conv weight of
    shape ``(out, in, 1, 1)``), the update is reshaped to the weight's shape.
    This avoids silently broadcasting into a much larger tensor. The merged
    weight keeps the base weight's dtype.

    Args:
        tensor_dict (dict[str, torch.Tensor]): State dict, e.g. as returned by
            ``safetensors.torch.load_file``.
        scale (float): Multiplier applied to ``B @ A`` (e.g. ``alpha / r`` if
            the adapter was trained with such scaling). Defaults to 1.0, i.e.
            the plain ``W + B @ A`` merge.

    Returns:
        dict[str, torch.Tensor]: New dict with merged weights and no LoRA
        keys. The input dict is not modified.

    Raises:
        ValueError: If ``B @ A`` has a different number of elements than the
            base weight it should be merged into.
    """
    suffix = ".weight"
    merged_tensors = {}

    for key, tensor in tensor_dict.items():
        if "lora_A" in key or "lora_B" in key:
            # Adapter (and *_temp bookkeeping) tensors are consumed, not copied.
            continue

        if not key.endswith(suffix):
            merged_tensors[key] = tensor
            continue

        base_key = key[: -len(suffix)]
        A = tensor_dict.get(base_key + ".lora_A")
        B = tensor_dict.get(base_key + ".lora_B")
        if A is None or B is None:
            # Plain weight without an adapter: keep as is.
            merged_tensors[key] = tensor
            continue

        delta = B @ A
        if delta.shape != tensor.shape:
            if delta.numel() != tensor.numel():
                raise ValueError(
                    f"Cannot merge LoRA into {key!r}: B @ A has shape "
                    f"{tuple(delta.shape)}, weight has shape {tuple(tensor.shape)}"
                )
            delta = delta.reshape(tensor.shape)

        merged_tensors[key] = (tensor + scale * delta).to(tensor.dtype)

    return merged_tensors


# Fold the LoRA adapters into the base weights of both models.
merged_text_encoder = merge_weights(model_text)
merged_unet = merge_weights(model_unet)

# Sanity check: no adapter tensors may survive the merge.
assert not any(
    "lora_A" in key or "lora_B" in key
    for state_dict in (merged_text_encoder, merged_unet)
    for key in state_dict
), "LoRA keys left after merge"


In [47]:
# Write each merged state dict next to the checkpoint it was derived from.
output_path = '/work/scratch/mcrespo/output/test_12_26_alldataset/12-29_17h45m15s/selora_outputs/loras/trained_model/'
safetensors.torch.save_file(
    merged_text_encoder,
    os.path.join(output_path, 'final_Text', 'merged_model.safetensors'),
)
safetensors.torch.save_file(
    merged_unet,
    os.path.join(output_path, 'final_Unet', 'merged_model.safetensors'),
)

In [45]:
# List every merged UNet tensor: its name on one line, its shape on the next.
for name, tensor in merged_unet.items():
    print(name, tensor.shape, sep="\n")

conv_in.bias
torch.Size([320])
conv_in.weight
torch.Size([320, 4, 3, 3])
conv_norm_out.bias
torch.Size([320])
conv_norm_out.weight
torch.Size([320])
conv_out.bias
torch.Size([4])
conv_out.weight
torch.Size([4, 320, 3, 3])
down_blocks.0.attentions.0.norm.bias
torch.Size([320])
down_blocks.0.attentions.0.norm.weight
torch.Size([320])
down_blocks.0.attentions.0.proj_in.bias
torch.Size([320])
down_blocks.0.attentions.0.proj_in.weight
torch.Size([320, 320, 1, 1])
down_blocks.0.attentions.0.proj_out.bias
torch.Size([320])
down_blocks.0.attentions.0.proj_out.weight
torch.Size([320, 320, 1, 1])
down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight
torch.Size([320, 320])
down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias
torch.Size([320])
down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight
torch.Size([320, 320])
down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight
torch.Size([320, 320])
down_blocks.0.attentions.0.transformer_blocks.0.a

In [32]:
# List every (unmerged) text-encoder tensor: name on one line, shape on the next.
for name, tensor in model_text.items():
    print(name, tensor.shape, sep="\n")

text_model.embeddings.position_embedding.weight
torch.Size([77, 768])
text_model.embeddings.token_embedding.weight
torch.Size([49408, 768])
text_model.encoder.layers.0.layer_norm1.bias
torch.Size([768])
text_model.encoder.layers.0.layer_norm1.weight
torch.Size([768])
text_model.encoder.layers.0.layer_norm2.bias
torch.Size([768])
text_model.encoder.layers.0.layer_norm2.weight
torch.Size([768])
text_model.encoder.layers.0.mlp.fc1.bias
torch.Size([3072])
text_model.encoder.layers.0.mlp.fc1.lora_A
torch.Size([21, 768])
text_model.encoder.layers.0.mlp.fc1.lora_A_temp
torch.Size([22, 768])
text_model.encoder.layers.0.mlp.fc1.lora_B
torch.Size([3072, 21])
text_model.encoder.layers.0.mlp.fc1.lora_B_temp
torch.Size([3072, 22])
text_model.encoder.layers.0.mlp.fc1.weight
torch.Size([3072, 768])
text_model.encoder.layers.0.mlp.fc2.bias
torch.Size([768])
text_model.encoder.layers.0.mlp.fc2.lora_A
torch.Size([1, 3072])
text_model.encoder.layers.0.mlp.fc2.lora_A_temp
torch.Size([2, 3072])
text_model.