In [6]:
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
from torch import nn, optim, utils
from torch.nn import functional as F
import lightning as L

from transformers.models.bart.modeling_bart import BartForConditionalGeneration, shift_tokens_right
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase

from latent_models.perceiver_ae import PerceiverAutoEncoder

class LatentGenerator(L.LightningModule):
    """Lightning module that fits an autoencoder around a frozen BART model.

    Every step runs the BART encoder on the batch, compresses its last hidden
    state with ``autoencoder.encode``, reconstructs it with
    ``autoencoder.decode`` and scores the reconstruction with the BART
    sequence-to-sequence loss against ``batch['labels']``.

    Args:
        pretrained_model: BART model; all of its parameters are frozen in
            ``__init__`` (``requires_grad_(False)``).
        autoencoder: Module exposing ``encode(hidden_states, attention_mask=...)``
            and ``decode(latents)``.
    """

    def __init__(
            self,
            pretrained_model: BartForConditionalGeneration,
            autoencoder: PerceiverAutoEncoder,
            ):
        super().__init__()
        self.lm = pretrained_model
        self.autoencoder = autoencoder

        # freeze pretrained model
        # NOTE(review): freezing does not switch the LM to eval mode, so its
        # dropout layers presumably stay active while training — confirm this
        # is intended.
        self.lm.requires_grad_(False)

    def _reconstruction_loss(self, batch):
        """Return the BART LM loss computed from autoencoder-reconstructed encoder states.

        Args:
            batch: Mapping providing ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.

        Returns:
            The ``loss`` attribute of the BART forward output.
        """
        attention_mask = batch['attention_mask']
        encoder_out = self.lm.get_encoder()(input_ids=batch['input_ids'], attention_mask=attention_mask)
        latents = self.autoencoder.encode(encoder_out['last_hidden_state'], attention_mask=attention_mask)
        reconstructed = self.autoencoder.decode(latents)
        if torch.is_tensor(reconstructed):
            # BART documents `encoder_outputs` as a tuple / ModelOutput whose first
            # element is the hidden states and reads it via `encoder_outputs[0]`.
            # A bare tensor would therefore be indexed along its first axis, so
            # wrap it. Non-tensor return values are passed through unchanged.
            reconstructed = (reconstructed,)
        return self.lm(labels=batch['labels'], encoder_outputs=reconstructed).loss

    # required
    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one training batch."""
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    # required
    def configure_optimizers(self):
        """Adam over the module parameters (the frozen LM weights receive no gradients)."""
        return optim.Adam(self.parameters(), lr=1e-3)

    # optional
    def validation_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one validation batch."""
        loss = self._reconstruction_loss(batch)
        self.log('valid_loss', loss)
        return loss

class Diffusion(L.LightningModule):
    """Thin Lightning wrapper that only stores the given diffusion model.

    Args:
        diffusion_model: Object kept on ``self.diffusion_model``; no training
            or validation logic is defined on this class.
    """

    def __init__(self, diffusion_model):
        super().__init__()
        self.diffusion_model = diffusion_model

@dataclass
class DataCollatorForBartDenoisingLM:
    """
    Data collator that turns tokenized examples into a BART training batch.

    Every key of the first example is stacked into a ``LongTensor`` across the
    batch, so all examples must carry equally long lists for each key.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data; its ``pad_token_id``
            identifies padding positions.
        decoder_start_token_id (:obj:`int`):
            Start token id forwarded to ``shift_tokens_right`` when building
            the decoder inputs from the labels.
    """
    tokenizer: PreTrainedTokenizerBase
    decoder_start_token_id: int

    def __call__(self, examples: List[Dict[str, List[int]]]) -> BatchEncoding:
        pad_id = self.tokenizer.pad_token_id

        # Stack each field of the examples into one tensor per key.
        stacked = {}
        for key in examples[0]:
            stacked[key] = torch.LongTensor([example[key] for example in examples])
        batch = BatchEncoding(stacked)

        # Labels start as a copy of the inputs; decoder inputs are derived
        # from them *before* padding is replaced by the ignore index.
        labels = batch["input_ids"].clone()
        batch["labels"] = labels
        batch["decoder_input_ids"] = shift_tokens_right(labels, pad_id, self.decoder_start_token_id)
        labels[labels == pad_id] = -100

        # Masks are recomputed from the pad id (overrides any incoming attention_mask).
        batch["attention_mask"] = (batch["input_ids"] != pad_id).long()
        batch["decoder_attention_mask"] = (batch["decoder_input_ids"] != pad_id).long()

        return batch

In [8]:
# --- experiment configuration ---
enc_dec_model = 'facebook/bart-base'
max_seq_len = 64
dataset_name = 'roc'
train_batch_size = 32

# --- pretrained BART backbone, its config and matching tokenizer ---
pretrained_model = BartForConditionalGeneration.from_pretrained(enc_dec_model)
model_config = pretrained_model.config
tokenizer = AutoTokenizer.from_pretrained(enc_dec_model)

# --- autoencoder sized to the BART hidden width, wrapped with the LM ---
autoencoder = PerceiverAutoEncoder(
    dim_lm=model_config.d_model,
    dim_ae=64,
    depth=3,
    num_encoder_latents=32,
    num_decoder_latents=32,
    max_seq_len=max_seq_len,
    transformer_decoder=True,
    l2_normalize_latents=False,
)
latent_generator = LatentGenerator(pretrained_model, autoencoder)

from datasets import load_dataset
import os
# Directory containing the split files `roc_train.json` and `roc_valid.json`.
roc_data_path = 'datasets/ROCstory'
# The files are read with the "text" loader, i.e. every line becomes one
# example with a single 'text' field.
dataset = load_dataset(
    "text",
    data_files={
        split: os.path.join(roc_data_path, f'roc_{split}.json')
        for split in ['train', 'valid']
    },
)
def process_roc_dataset(dataset):
    """Strip the `["..."]` wrapper from every line, shuffle, and carve out a test split.

    Args:
        dataset: Dataset dictionary with (at least) a 'valid' split whose
            examples have a 'text' field of the form `["<story>"]`.

    Returns:
        The mapped and shuffled dataset dictionary where 'valid' holds the
        first 1000 shuffled validation examples and 'test' holds the rest.

    Raises:
        AssertionError: if a line does not start with `["` or end with `"]`.
    """
    opening, closing = '["', '"]'

    def strip_wrapper(example):
        raw = example['text']
        assert raw.startswith(opening)
        assert raw.endswith(closing)
        return {'text': raw[len(opening):-len(closing)]}

    unwrapped = dataset.map(strip_wrapper)
    shuffled = unwrapped.shuffle(seed=42)
    # Hold out some validation samples for testing
    held_out = shuffled['valid'].train_test_split(train_size=1000, shuffle=False)
    shuffled['valid'] = held_out['train']
    shuffled['test'] = held_out['test']
    return shuffled
prep_dataset = process_roc_dataset(dataset)
# Tokenize every split to a fixed length of `max_seq_len` tokens (padded and
# truncated) and drop the raw text column.
map_prep_dataset = prep_dataset.map(
    lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True, max_length=max_seq_len),
    batched=True,
    remove_columns=['text'],
    num_proc=4,
)

from torch.utils.data import DataLoader
collate_fn = DataCollatorForBartDenoisingLM(tokenizer, model_config.decoder_start_token_id)
# Batch size comes from the `train_batch_size` setting defined with the other
# experiment constants instead of a repeated literal.
train_dataloader = DataLoader(
    map_prep_dataset['train'], batch_size=train_batch_size, shuffle=True,
    collate_fn=collate_fn, drop_last=True, num_workers=4,
)
# Evaluation loaders keep the final partial batch so that no held-out example
# is silently excluded from validation / test metrics.
valid_dataloader = DataLoader(
    map_prep_dataset['valid'], batch_size=train_batch_size, shuffle=False,
    collate_fn=collate_fn, drop_last=False, num_workers=4,
)
test_dataloader = DataLoader(
    map_prep_dataset['test'], batch_size=train_batch_size, shuffle=False,
    collate_fn=collate_fn, drop_last=False, num_workers=4,
)

Map: 100%|██████████| 93161/93161 [00:01<00:00, 79028.13 examples/s]
Map: 100%|██████████| 5000/5000 [00:00<00:00, 78896.95 examples/s]
Map (num_proc=4): 100%|██████████| 93161/93161 [00:02<00:00, 40896.91 examples/s]
Map (num_proc=4): 100%|██████████| 1000/1000 [00:00<00:00, 8719.82 examples/s]
Map (num_proc=4): 100%|██████████| 4000/4000 [00:00<00:00, 21270.64 examples/s]


In [19]:
# Take one training batch to drive the model summaries. `next(iter(...))`
# fails loudly on an empty loader instead of leaving `i` undefined.
i = next(iter(train_dataloader))

from torchinfo import summary

i = i.to('cuda')
# Layer-by-layer summary of the BART model on the sampled batch.
# NOTE(review): the encoder call below receives CUDA inputs, so it presumably
# relies on `summary(..., device='cuda')` having moved the module — confirm.
result = summary(latent_generator.lm, input_data={**i}, device='cuda', depth=10)
# Encoder hidden states and padding mask are the inputs of the autoencoder.
input_ae = latent_generator.lm.get_encoder()(input_ids=i['input_ids'], attention_mask=i['attention_mask'])
input_ae_attn = i['attention_mask']
result_2 = summary(latent_generator.autoencoder, input_data={'encoder_outputs': input_ae['last_hidden_state'], 'attention_mask': input_ae_attn}, device='cuda', depth=10)

In [7]:
import torch

# Load previously saved model objects from disk; later cells call methods on
# them directly (e.g. `latent_model.get_encoder()`, `diffusion_model.train_schedule`).
# NOTE(review): `torch.load` unpickles the file contents, which can execute
# arbitrary code — only load checkpoints from a trusted source.
latent_model = torch.load('saved_latent_models/roc/2024-03-24_17-58-54/model_full.pt')
diffusion_model = torch.load('saved_diff_models/roc/2024-03-24_17-59-30/diffusion.pt')

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Encode the BART encoder states into autoencoder latents.
# NOTE: depends on `latent_generator`, `input_ae` and `input_ae_attn` created
# in other cells; the recorded run failed with NameError because
# `latent_generator` was not defined in that kernel session.
output_ae = latent_generator.autoencoder.encode(input_ae['last_hidden_state'], attention_mask=input_ae_attn)

NameError: name 'latent_generator' is not defined

In [36]:
# All-True boolean mask with one row per latent batch entry and one column per
# encoder latent, placed on the latent model's device.
mask = torch.ones(output_ae.shape[0], latent_model.num_encoder_latents, dtype=torch.bool).to(latent_model.device)

In [39]:
import math
import numpy as np
from functools import partial

def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
    """Cosine noise schedule mapping times `t` to alpha values.

    With the default `start`/`end` the result is `cos(t * pi / 2) ** (2 * tau)`,
    i.e. 1 at `t = 0` and decreasing towards `clip_min` at `t = 1`.

    Args:
        t: Tensor of times.
        start, end: Sub-interval of the cosine curve onto which `t` is mapped.
        tau: Exponent control; the cosine is raised to the power `2 * tau`.
        clip_min: Lower bound applied to the returned values.

    Returns:
        Tensor with the same shape as `t`.
    """
    exponent = 2 * tau

    def _endpoint(u):
        # Scalar value of the (unnormalised) curve at position `u`.
        return math.cos(u * math.pi / 2) ** exponent

    value_at_start = _endpoint(start)
    value_at_end = _endpoint(end)
    curve = torch.cos((t * (end - start) + start) * math.pi / 2) ** exponent
    # Rescale so that the curve spans [0, 1] between `end` and `start`.
    normalised = (value_at_end - curve) / (value_at_end - value_at_start)
    return normalised.clamp(min = clip_min)

def log(t, eps = 1e-12):
    """Natural logarithm of `t` after clamping it from below at `eps`."""
    floored = torch.clamp(t, min = eps)
    return floored.log()

def log_snr_to_alpha(log_snr):
    alpha = torch.sigmoid(log_snr)
    return alpha

def alpha_to_shifted_log_snr(alpha, scale = 1):
    """Return `log(alpha / (1 - alpha))` clamped to [-15, 15] and shifted by `2 * log(scale)`."""
    snr = alpha / (1 - alpha)
    clamped_log_snr = log(snr).clamp(min=-15, max=15)
    shift = 2*np.log(scale).item()
    return clamped_log_snr + shift

def time_to_alpha(t, alpha_schedule, scale):
    """Map times to alpha values of the `scale`-shifted schedule.

    The schedule's alpha is converted to a shifted log-SNR and then back to an
    alpha value.

    Args:
        t: Tensor of times passed to `alpha_schedule`.
        alpha_schedule: Callable mapping times to alpha values.
        scale: Scale forwarded to `alpha_to_shifted_log_snr`.
    """
    return log_snr_to_alpha(alpha_to_shifted_log_snr(alpha_schedule(t), scale = scale))


# Bind the cosine schedule and a unit scale into a single `t -> alpha` callable.
alpha_schedule = cosine_schedule
scale = 1.
train_schedule = partial(time_to_alpha, alpha_schedule=alpha_schedule, scale=scale)

In [44]:
# Inspect the `num_classes` attribute of the wrapped diffusion network (recorded output: 0).
diffusion_model.diffusion_model.num_classes

0

In [54]:
# Sanity check: latent and mask shapes should agree on their first two axes.
output_ae.shape, mask.shape

(torch.Size([32, 32, 64]), torch.Size([32, 32]))

In [49]:
# Keyword inputs handed to `summary(diffusion_model, input_data=...)`.
diffusion_input = {
    'txt_latent': output_ae,
    'mask': mask,
    'class_id': 0,
}

In [51]:
# Layer-by-layer summary of the loaded diffusion model on the inputs above.
result_3 = summary(diffusion_model, input_data=diffusion_input, device='cuda', depth=10)

In [52]:
# Persist the textual summary; the file is overwritten on every run.
with open('diffusion_model_summary.txt', 'w') as f:
    f.write(str(result_3))

In [57]:
# Summary of the loaded latent model, fed with the sampled batch `i` unpacked as keyword inputs.
result_4 = summary(latent_model, input_data={**i}, device='cuda', depth=10)

In [59]:
# Persist the textual summary; the file is overwritten on every run.
with open('latent_model_summary.txt', 'w') as f:
    f.write(str(result_4))

In [70]:
# Run the loaded model's encoder on the sampled batch, then summarise its
# `perceiver_ae` sub-module on the resulting hidden states and padding mask.
# NOTE: rebinds `input_ae` / `input_ae_attn`, which another cell computes from `latent_generator`.
input_ae = latent_model.get_encoder()(input_ids=i['input_ids'], attention_mask=i['attention_mask'])
input_ae_attn = i['attention_mask']
result_5 = summary(latent_model.perceiver_ae, input_data={'encoder_outputs': input_ae['last_hidden_state'], 'attention_mask': input_ae_attn}, device='cuda', depth=10)

In [72]:
# Persist the textual summary; the file is overwritten on every run.
with open('latent_model_ae_summary.txt', 'w') as f:
    f.write(str(result_5))

In [9]:
import torch

# Scratch check: `view` can append `padding_dims` trailing singleton axes
# while the tensor itself keeps its original number of dimensions.
t = torch.ones(10, 23)
padding_dims = 5
t.view(*t.shape, *((1,) * padding_dims)).shape, t.ndim

(torch.Size([10, 23, 1, 1, 1, 1, 1]), 2)

In [13]:
# Draw one diffusion time per batch element, sampled uniformly between 0 and 1.
# No seed is set here, so the values differ from run to run.
batch = 32
times = torch.zeros((batch,)).float().uniform_(0, 1.)
times.shape, times

(torch.Size([32]),
 tensor([0.9520, 0.3880, 0.8722, 0.4655, 0.3569, 0.3600, 0.0963, 0.3076, 0.2223,
         0.1012, 0.7705, 0.4032, 0.8887, 0.2333, 0.4363, 0.0963, 0.3238, 0.1077,
         0.2151, 0.5173, 0.8530, 0.4075, 0.9194, 0.4000, 0.6558, 0.3870, 0.0626,
         0.8081, 0.0774, 0.8674, 0.9108, 0.4426]))

In [14]:
# Standard-normal noise with the latent shape recorded earlier (32, 32, 64).
# The shape is hard-coded so this cell does not depend on `output_ae`
# (alternative: `torch.randn_like(output_ae)`).
noise = torch.randn((32,32,64))
noise.shape

torch.Size([32, 32, 64])

In [15]:
# Map every sampled time to its alpha value using the loaded model's training schedule.
alpha = diffusion_model.train_schedule(times)
alpha.shape, alpha

(torch.Size([32]),
 tensor([0.0057, 0.6723, 0.0398, 0.5541, 0.7173, 0.7129, 0.9773, 0.7841, 0.8830,
         0.9749, 0.1245, 0.6497, 0.0303, 0.8716, 0.5994, 0.9773, 0.7628, 0.9717,
         0.8901, 0.4728, 0.0524, 0.6432, 0.0160, 0.6545, 0.2649, 0.6737, 0.9904,
         0.0881, 0.9853, 0.0428, 0.0195, 0.5897]))

In [104]:
def right_pad_dims_to(x, t):
    """Append trailing singleton axes to `t` until it has as many dims as `x`.

    Args:
        x: Reference tensor whose `ndim` is the target.
        t: Tensor to pad.

    Returns:
        `t` itself when it already has at least `x.ndim` dimensions, otherwise
        a view of `t` with the missing trailing axes of size 1.
    """
    missing = x.ndim - t.ndim
    if missing > 0:
        return t.view(*t.shape, *([1] * missing))
    return t
# Give `alpha` trailing singleton axes so it broadcasts against the latents
# (recorded shape: [32, 1, 1]).
alpha = right_pad_dims_to(output_ae, alpha)
alpha.shape

torch.Size([32, 1, 1])

In [107]:
# Move all operands of the noising step onto the GPU.
noise = noise.to('cuda')
alpha = alpha.to('cuda')
output_ae = output_ae.to('cuda')

In [109]:
# Noised latents: sqrt(alpha) * clean latents + sqrt(1 - alpha) * noise.
z_t = alpha.sqrt() * output_ae + (1-alpha).sqrt() * noise
z_t.shape

torch.Size([32, 32, 64])

In [44]:
import torch
import math
import numpy as np
from einops import repeat

def get_sampling_timesteps(self, batch, *, device, invert = False):
    """Build the (current, next) time pairs for every sampling step.

    Args:
        self: Object providing `sampling_timesteps` (number of steps).
        batch: Batch size; every pair is repeated for each batch element.
        device: Device on which the time grid is created.
        invert: When true the grid runs from 0 to 1 instead of from 1 to 0.

    Returns:
        Tuple with `sampling_timesteps` tensors of shape `(2, batch)`; row 0
        holds the current time and row 1 the following time of that step.
    """
    grid = torch.linspace(1., 0., self.sampling_timesteps + 1, device = device)
    grid = grid.flip(dims = (0,)) if invert else grid
    per_sample = repeat(grid, 't -> b t', b = batch)
    current, following = per_sample[:, :-1], per_sample[:, 1:]
    return torch.stack((current, following), dim = 0).unbind(dim = -1)

In [79]:
# Step-by-step walk through `get_sampling_timesteps`, printing the shape after
# each transformation.
sampling_timesteps = 250
batch = 32

# Time grid from 1 down to 0 with `sampling_timesteps + 1` points.
times = torch.linspace(1., 0., sampling_timesteps + 1)
# Inverted variant (left disabled): times = times.flip(dims = (0,))
print(times.shape)
# Repeat the grid for every batch element.
times = repeat(times, 't -> b t', b = batch)
print(times.shape)
# Pair each time with its successor along a new leading axis.
times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
print(times.shape)
# Split into one (2, batch) tensor per sampling step.
times = times.unbind(dim = -1)
print(len(times), times[0].shape)
times[0]

torch.Size([251])
torch.Size([32, 251])
torch.Size([2, 32, 250])
250 torch.Size([2, 32])


tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960,
         0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960,
         0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960, 0.9960,
         0.9960, 0.9960, 0.9960, 0.9960, 0.9960]])

In [42]:
# NOTE(review): these helpers repeat the definitions from an earlier cell of
# this notebook; whichever cell ran last determines the active versions.
def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
    """Cosine schedule: maps times `t` to alpha values, floored at `clip_min`.

    For the default `start`/`end` this equals `cos(t * pi / 2) ** (2 * tau)`.
    """
    exponent = 2 * tau
    value_at_start = math.cos(start * math.pi / 2) ** exponent
    value_at_end = math.cos(end * math.pi / 2) ** exponent
    curve = torch.cos((t * (end - start) + start) * math.pi / 2) ** exponent
    normalised = (value_at_end - curve) / (value_at_end - value_at_start)
    return normalised.clamp(min = clip_min)

def log(t, eps = 1e-12):
    """Natural logarithm of `t` after clamping it from below at `eps`."""
    floored = torch.clamp(t, min = eps)
    return floored.log()

def log_snr_to_alpha(log_snr):
    """Convert a log signal-to-noise ratio to alpha via the logistic sigmoid."""
    return torch.sigmoid(log_snr)

def alpha_to_shifted_log_snr(alpha, scale = 1):
    """Return `log(alpha / (1 - alpha))` clamped to [-15, 15] and shifted by `2 * log(scale)`."""
    snr = alpha / (1 - alpha)
    clamped_log_snr = log(snr).clamp(min=-15, max=15)
    shift = 2*np.log(scale).item()
    return clamped_log_snr + shift

def time_to_alpha(t, alpha_schedule, scale):
    """Evaluate `alpha_schedule` at `t` and re-express it on the `scale`-shifted log-SNR."""
    schedule_alpha = alpha_schedule(t)
    shifted = alpha_to_shifted_log_snr(schedule_alpha, scale = scale)
    return log_snr_to_alpha(shifted)

In [73]:
# Evaluate the schedule at t = 0.0, 0.1, ..., 0.9 and round-trip each alpha
# through the (unit-scale) log-SNR to confirm the conversion is lossless.
# NOTE: rebinds `alpha`, which other cells use for the per-sample alpha tensor.
alpha = [cosine_schedule(torch.tensor(t*(1e-1)), tau=1) for t in range(10)]
shifted_log_snr = [alpha_to_shifted_log_snr(a, scale=1) for a in alpha]
shifted_log_snr_to_alpha = [log_snr_to_alpha(s) for s in shifted_log_snr]

In [74]:
# Display the schedule values for t = 0.0 ... 0.9.
alpha

[tensor(1.),
 tensor(0.9755),
 tensor(0.9045),
 tensor(0.7939),
 tensor(0.6545),
 tensor(0.5000),
 tensor(0.3455),
 tensor(0.2061),
 tensor(0.0955),
 tensor(0.0245)]

In [75]:
# Display the corresponding log-SNR values (clamped to at most 15 for alpha = 1).
shifted_log_snr

[tensor(15.),
 tensor(3.6855),
 tensor(2.2484),
 tensor(1.3486),
 tensor(0.6389),
 tensor(-5.9605e-08),
 tensor(-0.6389),
 tensor(-1.3486),
 tensor(-2.2484),
 tensor(-3.6855)]

In [76]:
# Display the alphas recovered from the log-SNR values; they should match `alpha`.
shifted_log_snr_to_alpha

[tensor(1.0000),
 tensor(0.9755),
 tensor(0.9045),
 tensor(0.7939),
 tensor(0.6545),
 tensor(0.5000),
 tensor(0.3455),
 tensor(0.2061),
 tensor(0.0955),
 tensor(0.0245)]

In [77]:
# Ratio between consecutive schedule values alpha[i] / alpha[i + 1].
# NOTE: this loop rebinds the notebook-level name `i`, which other cells use
# for the sampled training batch.
for i in range(9):
    print(alpha[i]/alpha[i+1])

tensor(1.0251)
tensor(1.0785)
tensor(1.1393)
tensor(1.2130)
tensor(1.3090)
tensor(1.4472)
tensor(1.6763)
tensor(2.1584)
tensor(3.9021)


In [46]:
import torch
from torch.nn import functional as F
import math

# Scratch check of L2-normalising latents and rescaling by sqrt(feature dim).
# Every latent vector is set to the ramp 1..64 so the result is easy to read.
# latents = torch.ones((32, 32, 64))
latents = torch.arange(1, 65).reshape(1,1,64).repeat(32,32,1).float()
# latents[0,0,:]
# F.normalize(latents, dim=-1, p=2)[0,0,:]
# torch.sum(F.normalize(latents, dim=-1)[0,0,:])
# math.sqrt(latents.shape[-1])
# Unit-norm vector scaled by sqrt(64) = 8.
F.normalize(latents, dim=-1)[0,0,:] * math.sqrt(latents.shape[-1])

tensor([0.0268, 0.0535, 0.0803, 0.1070, 0.1338, 0.1605, 0.1873, 0.2140, 0.2408,
        0.2675, 0.2943, 0.3210, 0.3478, 0.3745, 0.4013, 0.4280, 0.4548, 0.4815,
        0.5083, 0.5350, 0.5618, 0.5885, 0.6153, 0.6420, 0.6688, 0.6955, 0.7223,
        0.7490, 0.7758, 0.8025, 0.8293, 0.8560, 0.8828, 0.9095, 0.9363, 0.9630,
        0.9898, 1.0165, 1.0433, 1.0700, 1.0968, 1.1235, 1.1503, 1.1770, 1.2038,
        1.2305, 1.2573, 1.2840, 1.3108, 1.3375, 1.3643, 1.3910, 1.4178, 1.4445,
        1.4713, 1.4980, 1.5248, 1.5515, 1.5783, 1.6050, 1.6318, 1.6585, 1.6853,
        1.7120])

In [47]:
# (latents / torch.sqrt(torch.sum(latents**2, dim=-1))[0,0])[0,0,:]
# torch.sqrt(torch.sum(F.normalize(latents, dim=-1, p=2)[0,0,:]**2))
# L2 norm of the rescaled vector: equals sqrt(feature dim), i.e. 8 for 64 features.
torch.sqrt(torch.sum((F.normalize(latents, dim=-1)[0,0,:] * math.sqrt(latents.shape[-1]))**2))

tensor(8.)