In [1]:
import re
import io
import os
import sys
import math
import requests

import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
from torchvision.utils import save_image
import torchvision.transforms.functional as TF

from pathlib import Path
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, random_split


In [2]:
# Ask PyTorch's caching allocator to release unused cached GPU memory before the run starts.
torch.cuda.empty_cache()

In [3]:
# Prefer the first CUDA GPU, but fall back to the CPU so that every later `.to(device)`
# still works on a machine without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
# Seed PyTorch's global RNG (layer initialisation below draws from it).
# NOTE(review): only torch is seeded here; numpy / Python `random` are not.
seed = 854
torch.manual_seed(seed)

<torch._C.Generator at 0x7fc85c2d2670>

## Bits Encoder

In [5]:
def encoder(binary_input):
    """Interpret a string of '0'/'1' characters as a base-2 number and return it as an int."""
    value = int(binary_input, base=2)
    return value

In [6]:
# Width (in bits) of one input symbol and the resulting number of distinct symbols.
bits = 8
bits_vocab_len = 1 << bits

print(f"bits vocab len: {bits_vocab_len}")

bits vocab len: 256


In [7]:
# Open the file in binary read mode
with open('../data/midi/MMD_MIDI/0/0/0/00000ec8a66b6bd2ef809b0443eeae41.mid', 'rb') as file:
    file_bytes = file.read()

# Render every byte as eight '0'/'1' characters and concatenate them into one bit string.
# A single join avoids rebuilding the string once per byte.
byte_stream = "".join(f'{byte:08b}' for byte in file_bytes)


In [8]:
# Number of complete `bits`-wide symbols contained in the sample file's bit string.
len(byte_stream)//bits

870

In [9]:
# Cut the bit string into consecutive `bits`-wide windows (a trailing partial window is
# dropped by the floor division) and decode each window to its integer value.
num_parts = len(byte_stream)//bits
a = [encoder(byte_stream[i*bits:(i+1)*bits]) for i in range(num_parts)]

In [10]:
def file_to_binary(path):
    """Read a file and return its contents as a string of '0'/'1' characters.

    :param path: Path of the file to read (opened in binary mode).
    :return: A string with eight characters per byte, most significant bit first;
        the empty string for an empty file.
    """
    with open(path, 'rb') as file:
        file_bytes = file.read()

    # The original built this string but never returned it, so callers always got None.
    return "".join(f'{byte:08b}' for byte in file_bytes)

## GPT

In [None]:
# Alternative GPT-2 backend. This cell has no execution count in the saved notebook;
# the Gemma cell further down rebinds `llm` and `model`.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load pre-trained GPT-2 model and tokenizer
llm = "gpt2"
model = GPT2LMHeadModel.from_pretrained(llm)
llm_tokenizer = GPT2Tokenizer.from_pretrained(llm)

In [None]:
# The LM head's weight matrix is the table that translate() compares feature vectors against.
embeddings = model.lm_head.weight
# embedding_matrix = model.transformer.wte.weight
llm_feature_dim = model.config.hidden_size
llm_vocab_len = model.config.vocab_size
# Move the LLM to the training device and switch it to inference mode.
model.to(device)
model.eval()

In [None]:
# Freeze the LLM: none of its parameters receive gradients.
for param in model.parameters():
    param.requires_grad = False

In [None]:
# Report the dimensions the mapper has to match.
print("gpt2 feature dim length:", llm_feature_dim)
print("gpt2 vocabulary length:", llm_vocab_len)
print("gpt2 embedding shape:", embeddings.shape)

## Google Gemma

In [11]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the Gemma-2B tokenizer and causal LM; `device_map="auto"` leaves weight placement
# to the library instead of an explicit `.to(device)`.
llm = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(llm)
model = AutoModelForCausalLM.from_pretrained(llm, device_map="auto")

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
# Display the loaded model's configuration (hidden size, vocabulary size, ...).
model.config

GemmaConfig {
  "_name_or_path": "google/gemma-2b",
  "architectures": [
    "GemmaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "head_dim": 256,
  "hidden_act": "gelu",
  "hidden_activation": null,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 16384,
  "max_position_embeddings": 8192,
  "model_type": "gemma",
  "num_attention_heads": 8,
  "num_hidden_layers": 18,
  "num_key_value_heads": 1,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.3",
  "use_cache": true,
  "vocab_size": 256000
}

In [13]:
# The LM head's weight matrix, one row per vocabulary token ([256000, 2048] in the
# recorded output); translate() uses it as the token embedding table.
embeddings = model.lm_head.weight
print(embeddings.shape)

torch.Size([256000, 2048])


In [14]:
# Dimensions the mapper has to match.
llm_feature_dim = model.config.hidden_size
llm_vocab_len = model.config.vocab_size
# Explicit placement is disabled here; the model was loaded with device_map="auto".
# model.to(device)
# Inference mode for the LLM.
model.eval()

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNo

In [15]:
# Freeze the LLM: none of its parameters receive gradients; only the mapper networks are trained.
for param in model.parameters():
    param.requires_grad = False

In [16]:
# Report the dimensions the mapper has to match.
print("llm feature dim length:", llm_feature_dim)
print("llm vocabulary length:", llm_vocab_len)
print("llm embedding shape:", embeddings.shape)

llm feature dim length: 2048
llm vocabulary length: 256000
llm embedding shape: torch.Size([256000, 2048])


## Mapper Network

In [17]:
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

In [18]:
class TokenMapper(nn.Module):
    """A single affine layer (`nn.Linear` with bias) from `input_dim` to `output_dim`.

    The layer is stored as `self.mapper`, so its weight has shape
    [output_dim, input_dim].
    """

    def __init__(self, input_dim, output_dim, device="cpu"):
        super().__init__()
        # Build on the default device, then move, exactly like a separate `.to(device)` call.
        self.mapper = nn.Linear(input_dim, output_dim).to(device)

    def forward(self, one_hot_token):
        """Apply the affine map to the last dimension of `one_hot_token`."""
        projected = self.mapper(one_hot_token)
        return projected

In [19]:
# Create the mapper.
# It maps a one-hot symbol over the input vocabulary (bits_vocab_len entries) to a
# vector of the LLM's feature dimension.
# mapper = TokenMapper(midi_vocab_len, llm_feature_dim, device=device)
mapper = TokenMapper(bits_vocab_len, llm_feature_dim, device=device)

In [20]:
# Display the mapper's layer structure.
mapper

TokenMapper(
  (mapper): Linear(in_features=256, out_features=2048, bias=True)
)

In [21]:
# Second mapper with the same dimensions as `mapper`. The training loop's RL branch multiplies
# LLM hidden states by its weight ([llm_feature_dim, bits_vocab_len]) to get symbol logits.
reverseMapper = TokenMapper(bits_vocab_len, llm_feature_dim, device=device)

In [22]:
# Display the reverse mapper's layer structure.
reverseMapper

TokenMapper(
  (mapper): Linear(in_features=256, out_features=2048, bias=True)
)

## Prompt Network

In [23]:
# Number of learned prompt vectors; 0 means the prompt network below is not created.
prompt_len = 0

In [24]:
class Prompt(nn.Module):
    """A bias-free linear layer from `input_dim` to `output_dim`, stored as `self.model`."""

    def __init__(self, input_dim, output_dim, device="cpu"):
        super().__init__()
        # Build on the default device, then move, exactly like a separate `.to(device)` call.
        self.model = nn.Linear(input_dim, output_dim, bias=False).to(device)

    def forward(self, one_hot_token):
        """Apply the linear map to the last dimension of `one_hot_token`."""
        projected = self.model(one_hot_token)
        return projected

In [25]:
if prompt_len != 0:
    # One one-hot row per prompt position; feeding these through `prompt` selects
    # one learned vector per position.
    prompt_inputs = F.one_hot(
        torch.arange(prompt_len),
        num_classes=prompt_len,
    ).float().to(device)
    prompt = Prompt(prompt_len, llm_feature_dim, device=device)

## Generate Ground Truth

In [26]:
def generate_next_token_predictions(token_sequences):
    """Run the global `model` on token ids without gradient tracking.

    Despite the name, this returns the last entry of the model's `hidden_states`,
    not vocabulary logits.
    """
    with torch.no_grad():
        llm_out = model(input_ids=token_sequences, output_hidden_states=True)
        last_hidden = llm_out.hidden_states[-1]

    return last_hidden

In [27]:
def generate_next_token_predictions_withfv(token_fv):
    """Run the global `model` on input embeddings without gradient tracking.

    `token_fv` is passed as `inputs_embeds`. The return value is the last entry of
    the model's `hidden_states`, and no gradient flows back through the LLM.
    """
    with torch.no_grad():
        llm_out = model(inputs_embeds=token_fv, output_hidden_states=True)
        last_hidden = llm_out.hidden_states[-1]

    return last_hidden

In [28]:
def translate(batch_feature_vectors, embeddings, temperature=1.0):
    """Softly map feature vectors onto rows of an embedding table by cosine similarity.

    :param batch_feature_vectors: 3-D tensor [batch, seq, dim].
    :param embeddings: 2-D table [vocab, dim].
    :param temperature: Divisor applied to the similarities before the softmax.
    :return: Tuple of (softmax-weighted mixture of embedding rows [batch, seq, dim],
        cosine similarities [batch, seq, vocab], index of the best row [batch, seq]).
    """
    # The three-way unpack doubles as a check that the input is 3-D.
    _batch, _steps, _width = batch_feature_vectors.shape

    unit_queries = F.normalize(batch_feature_vectors, dim=2)
    unit_table = F.normalize(embeddings, dim=1)
    cosine_similarities = unit_queries @ unit_table.T

    # Temperature-scaled distribution over the table rows for every position.
    row_weights = (cosine_similarities / temperature).softmax(dim=2)
    closest_tokens = row_weights.argmax(dim=2)

    # Mix the un-normalised embedding rows with those weights.
    mixed = row_weights @ embeddings

    return mixed, cosine_similarities, closest_tokens

## Reinforce Loss

In [29]:
def Reinforce_Loss(logits, translated, loss, gamma=0.9, alpha=1, temperature=1):
    """
    Calculate a REINFORCE-style loss for sequence prediction.

    Each position's log-probability of its chosen token is weighted by a normalised,
    discounted sum of the per-step losses from that position onward.

    :param logits: Scores over the action vocabulary, shape [batch_size, seq_len, vocab_size].
    :param translated: Index of the chosen token at every step, shape [batch_size, seq_len].
    :param loss: Per-step loss signal, shape [batch_size, seq_len].
    :param gamma: Discount factor applied to later steps.
    :param alpha: Scalar divisor applied to the discounted losses.
    :param temperature: Divisor applied to `logits` before the log-softmax.
    :return: Scalar tensor: sum over all positions of
        log_prob(chosen token) * discounted loss, divided by batch_size * seq_len.
    """
    batch_size, seq_len, _ = logits.shape
    translated = translated.to(torch.int64)
    # shape = [batch_size, seq_len, vocab_size]
    log_probs = F.log_softmax(logits/temperature, dim=-1)
    log_probs_targets = log_probs.gather(2, translated.unsqueeze(2)).squeeze(2)

    # discount_matrix[i, j] = gamma ** (j - i) for j >= i, and 0 below the diagonal.
    # It is built in one vectorised step on the same device as `logits`, so the function
    # does not rely on a module-level `device`.
    steps = torch.arange(seq_len, device=logits.device, dtype=torch.float32)
    offsets = (steps.unsqueeze(0) - steps.unsqueeze(1)).clamp(min=0)
    discount_matrix = torch.triu(float(gamma) ** offsets)

    # Row sums turn the discounted sum into a weighted average.
    normalize_factor = discount_matrix.sum(dim=1)

    # [batch, 1, seq] * [seq, seq] -> [batch, seq, seq]; entry [b, i, j] = loss[b, j] * gamma**(j-i)
    discounted_loss = loss.unsqueeze(1) * discount_matrix
    cumulative_loss = discounted_loss.sum(dim=-1) / normalize_factor / alpha

    total_loss = torch.sum(log_probs_targets * cumulative_loss) / batch_size / seq_len

    return total_loss

## Get Dataset

In [30]:
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import BatchSampler, SequentialSampler

In [31]:
# Symbol width in bits; re-declares the value set in the "Bits Encoder" section.
bits = 8

In [32]:
# Sequences per training batch.
batch_size = 10
# Side length (pixels) used by resize_and_crop in the image pipeline.
image_size = 128
# Symbols per training sequence; also the default `group_size` of GroupedBatchSampler.
seq_len = 256

### Midi Bits

In [33]:
# Function to process a single file
def process_file(file_path):
    """Read one file and return its contents as a list of `bits`-wide integer symbols.

    The bytes are rendered as one long '0'/'1' string, which is then cut into
    consecutive windows of the global `bits` width; a trailing partial window is dropped.
    """
    with open(file_path, 'rb') as handle:
        raw = handle.read()

    bit_string = "".join(f'{value:08b}' for value in raw)
    symbol_count = len(bit_string) // bits

    return [
        encoder(bit_string[k * bits:(k + 1) * bits])
        for k in range(symbol_count)
    ]

# Main function to walk through the directory and process each file
def process_directory(file_path):
    """Concatenate the symbols of every MIDI file found under a directory tree.

    :param file_path: Root directory to search recursively.
    :return: One flat list holding the output of process_file for each matching file,
        in sorted traversal order.
    """
    supported_extensions = ('.mid', '.midi')
    all_integers = []

    for root, dirs, files in os.walk(file_path):
        # os.walk yields entries in arbitrary filesystem order. Sorting the directory list
        # in place and the file names makes the dataset identical across runs and machines.
        dirs.sort()
        for file in sorted(files):
            # Compare case-insensitively so names such as "SONG.MID" are not skipped.
            if file.lower().endswith(supported_extensions):
                full_path = os.path.join(root, file)
                all_integers.extend(process_file(full_path))

    return all_integers

In [34]:

# Root directory of the MIDI corpus to load.
file_path = '../data/midi/Maestro'
# Read every MIDI file below it into one flat list of integer symbols.
resulting_integers = process_directory(file_path)

In [35]:
# Total number of symbols loaded from the corpus.
len(resulting_integers)

83869091

In [36]:
class IntegerDataset(Dataset):
    """Map-style dataset that exposes an indexable collection one element at a time."""

    def __init__(self, data):
        # Kept by reference; no copy or conversion happens here.
        self.data = data

    def __getitem__(self, idx):
        item = self.data[idx]
        return item

    def __len__(self):
        size = len(self.data)
        return size


In [37]:
class GroupedBatchSampler(BatchSampler):
    """BatchSampler that yields blocks of `batch_size * group_size` consecutive indices.

    Each yielded list covers one contiguous stretch of the dataset, so it can be
    reshaped into `batch_size` sequences of `group_size` neighbouring elements.

    :param dataset: Sized dataset to draw indices from.
    :param batch_size: Number of sequences per yielded block.
    :param group_size: Number of consecutive elements per sequence.
    :param shuffle: When True, the ORDER of the blocks is randomised with torch's global
        RNG on every iteration. Indices inside a block stay consecutive, because
        shuffling single indices would destroy the sequences.
    """

    def __init__(self, dataset, batch_size=128, group_size=seq_len, shuffle=False):
        sampler = SequentialSampler(dataset)
        super().__init__(sampler, batch_size=batch_size * group_size, drop_last=False)
        self.group_size = group_size
        # Previously accepted but silently ignored.
        self.shuffle = shuffle

    def __iter__(self):
        total = len(self.sampler)
        # First index of every block; the last block may be shorter than self.batch_size.
        block_starts = list(range(0, total, self.batch_size))
        if self.shuffle:
            order = torch.randperm(len(block_starts)).tolist()
            block_starts = [block_starts[k] for k in order]

        for start in block_starts:
            stop = min(start + self.batch_size, total)
            if stop - start < self.batch_size and self.drop_last:
                continue
            yield list(range(start, stop))


In [38]:
# Wrap the flat symbol list in a dataset
dataset = IntegerDataset(resulting_integers)

# Sampler yielding batch_size * seq_len indices per step
grouped_batch_sampler = GroupedBatchSampler(dataset, batch_size=batch_size, group_size=seq_len, shuffle=True)

# DataLoader driven by the custom batch sampler (one sampler step = one loader batch)
dataloader = DataLoader(dataset, batch_sampler=grouped_batch_sampler)

In [39]:
for grouped_batch in dataloader:
    # `grouped_batch` is a flat 1-D tensor of batch_size * seq_len symbols
    # (torch.Size([2560]) in the recorded run); the training loop reshapes it
    # to [-1, seq_len] itself.
    break

In [40]:
# Shape of one loader batch.
grouped_batch.shape

torch.Size([2560])

### Image Bits

In [None]:
from PIL import Image
from dall_e import map_pixels

In [None]:
class ToBinaryString:
    """Transform turning a PIL image or tensor into a list of `bits`-wide integer symbols.

    Non-tensor input is first converted with torchvision's ToTensor. The values are
    multiplied by 255 and cast to uint8, serialised to raw bytes, rendered as a
    '0'/'1' string and cut into windows of the global `bits` width. Unlike
    process_file, a trailing partial window is kept and decoded as a shorter number.
    """
    def __call__(self, pic):
        # Convert PIL Image to Tensor if it's not already a tensor
        if not isinstance(pic, torch.Tensor):
            transform_to_tensor = T.ToTensor()
            pic = transform_to_tensor(pic)

        # Cast to uint8 and move to the CPU; `.numpy()` below raises on CUDA tensors.
        # NOTE(review): the `* 255` assumes values in [0, 1] (ToTensor's range), so confirm
        # that callers never pass a tensor that is already uint8.
        pic = (pic * 255).byte().cpu()

        # Convert the image tensor to a numpy array and then to bytes
        image_bytes = pic.numpy().tobytes()

        # Convert bytes to a bit string, then to `bits`-wide integers
        binary_string = ''.join(f'{byte:08b}' for byte in image_bytes)
        integers = [int(binary_string[i:i+bits], 2) for i in range(0, len(binary_string), bits)]

        return integers

In [None]:
def resize_and_crop(img):
    """Scale `img` so its shorter side equals the global `image_size`, then center-crop a square.

    `img.size` is read as (width, height); the resize target is passed as (height, width).
    """
    width, height = img.size[0], img.size[1]
    scale = image_size / min(img.size)
    target_hw = (round(scale * height), round(scale * width))
    resized = TF.resize(img, target_hw, interpolation=Image.LANCZOS)
    return TF.center_crop(resized, output_size=[image_size, image_size])

def modified_map_pixels(img):
    """Apply `map_pixels` to a single image by wrapping it in a temporary batch dimension."""
    batched = img.unsqueeze(0)
    mapped = map_pixels(batched)
    # Drop the batch dimension that was added above.
    return mapped.squeeze(0)

# Image preprocessing pipeline: resize + center-crop, then convert to a tensor.
# The ToBinaryString step is currently disabled.
transform = T.Compose([
    T.Lambda(resize_and_crop),
    T.ToTensor()
    # ToBinaryString(),
])

In [None]:
# Root directory of the LSUN image data.
lsun_path = '../data/lsun'

# NOTE: running this cell rebinds `dataset` and `dataloader`, replacing the MIDI ones above.
dataset = datasets.LSUN(root=lsun_path, classes=['classroom_train'], transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# NOTE(review): len(dataloader) is the number of batches, not the number of samples;
# use len(dataset) if the sample count is what the label is meant to show.
print("dataset size:", len(dataloader))

In [None]:
# Scratch directory for temporary .jpg files; created (with parents) if it does not exist yet.
temp_dir = 'temp_images'
Path(temp_dir).mkdir(parents=True, exist_ok=True)

In [None]:
def image_to_binary(file_path):
    """
    Read a file and return its raw bytes as a list of `bits`-wide integers.

    The bytes are rendered as one '0'/'1' string and cut into windows of the global
    `bits` width; a trailing partial window is kept and decoded as a shorter number.
    """
    with open(file_path, 'rb') as handle:
        raw = handle.read()

    bit_string = ''.join(f'{value:08b}' for value in raw)

    symbols = []
    for start in range(0, len(bit_string), bits):
        symbols.append(int(bit_string[start:start + bits], 2))

    return symbols

## Train Model

In [41]:
# Hyper Parameters
# Adam step size.
learning_rate = 1e-5
# Passes over the dataloader.
epochs = 1
# NOTE(review): `gamma` is used both as the ExponentialLR decay factor and as the
# Reinforce_Loss discount factor; confirm that one value is meant for both.
gamma = 0.1
# Softmax temperature passed to translate() and Reinforce_Loss.
temperature = 0.001
# Divisor applied to the discounted losses in Reinforce_Loss.
alpha = 1

In [42]:
# Run identifiers. `experiment_name` doubles as the TensorBoard log sub-directory and the
# checkpoint sub-directory, and the optimizer cell selects trainable parameters by
# looking for "dual" / "single" inside it.
experiment = "base_single"
# The training loop enables its branches by substring: 'base' and/or 'rl'.
algo = "base"
exp_type = "midi"
name = f"{bits}bits"
# `llm` may contain a slash (e.g. "google/gemma-2b"), which adds a directory level.
experiment_name = f"{exp_type}/{algo}/{experiment}/{name}/{llm}/lr={learning_rate},gamma={gamma},temp={temperature},promptlen={prompt_len}"

experiment_name

'midi/base/base_single/8bits/google/gemma-2b/lr=1e-05,gamma=0.1,temp=0.001,promptlen=0'

### Load Models

In [None]:
# mapper.load_state_dict(torch.load(f"../models/{experiment_name}/mapper.pt"))
# reverseMapper.load_state_dict(torch.load(f"../models/{experiment_name}/reversemapper.pt"))

### Training

In [43]:
from torch.utils.tensorboard import SummaryWriter

# Create a SummaryWriter instance; logs are written under ../runs/<experiment_name>
writer = SummaryWriter(log_dir = f'../runs/{experiment_name}')

In [44]:
# Mean cross-entropy for the supervised ("base") objective.
criterion = nn.CrossEntropyLoss()
# Select the trainable parameters from the experiment name.
if prompt_len==0 and "dual" in experiment_name:
    optimizer = optim.Adam(list(mapper.parameters()) + list(reverseMapper.parameters()), lr=learning_rate)
elif "dual" in experiment_name:
    optimizer = optim.Adam(list(mapper.parameters()) + list(prompt.parameters()), lr=learning_rate)
elif "single" in experiment_name:
    optimizer = optim.Adam(mapper.parameters(), lr=learning_rate)
else:
    # Fail here with a clear message instead of a NameError on `optimizer` below.
    raise ValueError(
        f"experiment_name must contain 'dual' or 'single' to select the trainable "
        f"parameters, got {experiment_name!r}"
    )
# Per-token (unreduced) cross-entropy, used as the per-step signal in the RL branch.
rl_criterion = nn.CrossEntropyLoss(reduction='none')
# NOTE(review): the decay factor reuses `gamma`, which is also the RL discount factor.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

In [45]:

# Training loop.
# Per batch: symbols -> mapper -> soft LLM-token mixture (translate) -> frozen LLM ->
# logits over the symbol vocabulary -> next-symbol cross-entropy. With 'rl' in `algo`
# an additional REINFORCE update follows.
#
# NOTE(review): translate() materialises [batch, seq_len, llm_vocab_len] tensors. For
# 10 x 256 x 256000 values, assuming float32, each is about 2.44 GiB, the same size as
# the allocation in the recorded out-of-memory error. Lower batch_size or seq_len if
# memory is tight.
for epoch in range(epochs):
    mapper.train()
    for i, dd in enumerate(dataloader):

        # Start every batch from clean gradients. Without this, gradients from all
        # previous batches keep accumulating in the 'base' path.
        optimizer.zero_grad()

        # The MIDI loader yields a flat tensor of symbols. Keep only whole sequences so
        # the final, shorter batch can still be reshaped to [-1, seq_len].
        data = dd.reshape(-1)
        usable = (data.numel() // seq_len) * seq_len
        if usable == 0:
            continue

        ground_truth_tokens = data[:usable].reshape(-1, seq_len).to(device)
        one_hot_tokens = F.one_hot(ground_truth_tokens, num_classes=bits_vocab_len).float()

        # Logits are to be compared with the next ground truth tokens
        ground_truth_tokens = ground_truth_tokens[:,1:]
        inputs_feature_vector = mapper(one_hot_tokens)

        # Map the symbol features onto the LLM's embedding table
        translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)

        # Last-layer representation from the frozen LLM (computed under no_grad)
        final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)

        # Project the LLM representation back to symbol logits with the mapper's own weight
        logits = torch.matmul(final_layer_fv, mapper.mapper.weight)
        # The last position has no next symbol to predict
        logits = logits[:,:-1]
        logits_ = logits.reshape(-1, bits_vocab_len)
        ground_truth_tokens = ground_truth_tokens.reshape(-1)
        ce_loss = criterion(logits_, ground_truth_tokens)

        # One backward pass and one optimizer step per batch. The previous code repeated
        # both inside the 'base' branch, doubling the gradient and stepping twice.
        # retain_graph is not needed: nothing below backpropagates through this graph again.
        ce_loss.backward()
        optimizer.step()
        if 'base' in algo:
            writer.add_scalar("training/cross_entropy_base", ce_loss.item(), epoch*len(dataloader)+i)
            if i%50==0:
                print(f"Epoch {epoch+1}, Batch {i}, CE Loss: {ce_loss.mean().item()}")
        # RL Loss
        if 'rl' in algo:
            optimizer.zero_grad()
            translated_feature_vector, translate_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)
            with torch.no_grad():
                # The per-token cross-entropy acts as the per-step signal; no gradient is needed here
                final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)
                logits = torch.matmul(final_layer_fv, reverseMapper.mapper.weight)
                logits = logits[:,prompt_len:-1]

                logits_ = logits.reshape(-1, bits_vocab_len)
                ce_loss = rl_criterion(logits_, ground_truth_tokens)
                ce_loss = ce_loss.reshape(-1, logits.size(1))

            rl_loss = Reinforce_Loss(translate_logits[:,1:], translated_text_tokens[:,1:].detach(), ce_loss, alpha=alpha, gamma=gamma, temperature=temperature)

            rl_loss.backward()
            optimizer.step()
            # Log the losses
            # NOTE(review): the fixed +1000+1000 step offset is kept as found; confirm its purpose.
            writer.add_scalars(
                "training",
                {
                    "rl_loss": rl_loss.item(),
                    "cross_entropy_rl": ce_loss.mean().item(),
                },
                epoch * len(dataloader) + i + 1000 + 1000
            )

            torch.cuda.empty_cache()

            if i % 50 == 0:
                print(f"Epoch {epoch+1}, Batch {i}, CE Loss: {ce_loss.mean().item()}, RL Loss: {rl_loss.item()}")

    # Decay the learning rate once per epoch
    scheduler.step()
    print(f"Epoch {epoch+1}/{epochs} completed.")
writer.close()

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.44 GiB. GPU 0 has a total capacty of 10.91 GiB of which 254.12 MiB is free. Including non-PyTorch memory, this process has 10.66 GiB memory in use. Of the allocated memory 10.09 GiB is allocated by PyTorch, and 6.80 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Save both mappers' weights under ../models/<experiment_name>/ (created if missing).
# NOTE(review): with algo="base" and a "single" experiment the optimizer only holds
# mapper's parameters, so reverseMapper is saved with its initial weights.
Path(f"../models/{experiment_name}").mkdir(parents=True, exist_ok=True)
torch.save(mapper.state_dict(), f"../models/{experiment_name}/mapper.pt")
torch.save(reverseMapper.state_dict(), f"../models/{experiment_name}/reversemapper.pt")

In [None]:
# Close the TensorBoard writer. The training cell also calls close() when it finishes,
# so this matters mainly when that cell was interrupted.
writer.close()