In [None]:
import torch
from separability import Model
from separability.texts import prepare
from tqdm import tqdm

In [None]:
# Load the model whose activations are extracted in the following cells.
# NOTE(review): the positional `1000` is presumably a token/length limit --
# confirm against the `Model` constructor signature.
m = Model("facebook/galactica-125m", 1000, dtype="fp16")

# `prepare("pile")` returns three values: the dataset to iterate, `label`
# (used below to index each record for its text), and `skip` (not used in
# the code visible here).
dataset, label, skip = prepare("pile")

In [None]:
# Collect the first value returned by `get_text_activations` for each of the
# leading records of the dataset.
activations = []
for i, data in enumerate(dataset, start=1):
    text = data[label]
    inpt, attn, ff, outpt = m.get_text_activations(text, limit=1000)
    activations.append(inpt)
    # The check runs after the append, so the loop stops once the 101st
    # record has been processed.
    if i > 100:
        break

print(activations)

In [None]:
# Flatten the per-text activations (List[Tensor]) into one tensor: every entry
# is iterated along its first dimension and the resulting rows are stacked,
# then cast to float32.
inputs = torch.stack(
    [row for activation in activations for row in activation]
).to(dtype=torch.float32)
print(inputs.shape)

# NOTE: this rebinds `dataset`, which previously held the text dataset, to the
# tensor dataset that the training code iterates over.
dataset = torch.utils.data.TensorDataset(inputs)


# Minimal linear autoencoder built from two torch Linear layers.
class AutoEncoder(torch.nn.Module):
    """Linear encoder followed by a linear decoder, with no activation between.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        encoding_dim: Size of the intermediate encoded representation.
    """

    def __init__(self, input_dim, encoding_dim):
        super().__init__()

        # Each half is wrapped in a Sequential so parameters are registered
        # under "encoder.0.*" and "decoder.0.*"; the encoder is built first.
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, encoding_dim)
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(encoding_dim, input_dim)
        )

    def forward(self, x):
        """Encode `x`, decode the result, and return the reconstruction."""
        return self.decoder(self.encoder(x))

def train_autoencoder(
    dim,
    loss_fn=None,
    init_lr=1e-3,
    n_epochs=50,
    batch_size=64,
    device=None,
):
    """Train an `AutoEncoder` with bottleneck size `dim` on the module-level `dataset`.

    The model's input size is taken from `m.cfg.d_model`. After training, the
    first six values of the first row of the module-level `inputs` are printed
    next to their reconstruction as a quick visual check.

    Args:
        dim: Encoding (bottleneck) dimension.
        loss_fn: Callable `(outputs, targets) -> loss`. Defaults to
            `torch.nn.MSELoss()`.
        init_lr: Learning rate for the first epoch; epoch `e` (0-based) uses
            `init_lr / (e + 1)`.
        n_epochs: Number of passes over `dataset`.
        batch_size: Mini-batch size for the data loader.
        device: Device to train on. Defaults to "cuda" when available,
            otherwise "cpu".

    Returns:
        The trained `AutoEncoder`.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    ae = AutoEncoder(input_dim=m.cfg.d_model, encoding_dim=dim)
    ae = ae.to(device=device, dtype=torch.float32)

    if loss_fn is None:
        loss_fn = torch.nn.MSELoss()

    # Build the optimizer once: re-creating Adam every epoch would throw away
    # its running moment estimates. The per-epoch decay is applied by updating
    # the learning rate of the existing parameter groups instead.
    optimizer = torch.optim.Adam(ae.parameters(), lr=init_lr)

    # The loader reshuffles each time it is iterated, so it can be shared
    # across epochs.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )

    for epoch in (pbar := tqdm(range(n_epochs))):
        epoch_lr = init_lr / (epoch + 1)
        for group in optimizer.param_groups:
            group["lr"] = epoch_lr

        loss = None
        for [batch] in loader:
            # The model lives on `device`; the stored tensors may not.
            batch = batch.to(device)

            outputs = ae(batch)
            loss = loss_fn(outputs, batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the loss of the last batch; skip if the loader was empty.
        if loss is not None:
            pbar.set_postfix(loss=loss.item())

    # Print one sample next to its reconstruction.
    with torch.no_grad():
        for inpt in inputs[:1]:
            output = ae(inpt.to(device))
            print(inpt[:6].detach().cpu().to(torch.float16).numpy())
            print(output[:6].detach().cpu().to(torch.float16).numpy())

    return ae
  
# Train one autoencoder per bottleneck size, from widest to narrowest.
for dim in (768, 512, 420, 360, 256, 128, 64, 32):
    train_autoencoder(dim)
