In [1]:
import pickle
import torch
from torch.nn.functional import cosine_similarity

In [2]:
# Load the pickled intermediate outputs for both models.
# Context managers guarantee the file handles are closed once loading finishes.
# NOTE(review): pickle.load can execute arbitrary code while deserialising --
# only load files that were produced by a trusted pipeline.
with open("tokenized/intermediate_llama_3b_18.pkl", 'rb') as f:
    intermediate_3b_18 = pickle.load(f)
with open("tokenized/first_2000/intermediate_llama_1b_10.pkl", 'rb') as f:
    intermediate_10_1b = pickle.load(f)

In [3]:
# Sanity check: number of top-level entries in the loaded 3B object.
print(len(intermediate_3b_18))

2000


In [4]:
# Flatten the first NUM_SEQUENCES entries of each loaded object into one long
# list of row tensors, then stack each list into a single 2-D tensor.
# Only a prefix is used; raise NUM_SEQUENCES (up to len(intermediate_3b_18))
# to train on more data.
NUM_SEQUENCES = 1000

source_rows = []
target_rows = []
for i in range(NUM_SEQUENCES):
    # extend() iterates the entry, so every item inside it becomes its own row.
    source_rows.extend(intermediate_3b_18[i])
    target_rows.extend(intermediate_10_1b[i])

# Row k of the source is later paired with row k of the target, so the two
# collections must be the same length; fail early instead of mis-pairing.
if len(source_rows) != len(target_rows):
    raise ValueError(
        f"source/target row counts differ: {len(source_rows)} vs {len(target_rows)}"
    )

# Release each temporary list as soon as its stacked tensor exists to keep
# peak memory down.
intermediate_3b_18_all = torch.stack(source_rows)
del source_rows
intermediate_10_1b_all = torch.stack(target_rows)
del target_rows

In [5]:
# Show the stacked tensor shapes; the first dimensions should match because
# rows are paired by position downstream.
print(intermediate_3b_18_all.shape)
print(intermediate_10_1b_all.shape)

torch.Size([448340, 3072])
torch.Size([448340, 2048])


In [6]:
# Drop the names bound to the raw loaded objects; later cells only use the
# stacked *_all tensors. (Presumably done to let the kernel reclaim memory.)
del intermediate_3b_18
del intermediate_10_1b

In [7]:
def split_data(source, target, train_ratio=0.8, generator=None):
    """Randomly split paired source/target rows into train and validation sets.

    One random permutation of the row indices is drawn and applied to both
    tensors, so row ``i`` of ``source`` stays paired with row ``i`` of
    ``target`` in whichever split it lands in.

    Args:
        source: Tensor whose first dimension indexes samples.
        target: Tensor with the same number of rows as ``source``.
        train_ratio: Fraction of rows assigned to the training split. The
            training size is ``int(train_ratio * len(source))``.
        generator: Optional ``torch.Generator`` used to draw the permutation,
            allowing a reproducible split. ``None`` uses the global RNG.

    Returns:
        Tuple ``(source_train, target_train, source_val, target_val)``.

    Raises:
        ValueError: If ``source`` and ``target`` have different lengths, or if
            ``train_ratio`` is outside ``[0, 1]``.
    """
    num_samples = len(source)
    if len(target) != num_samples:
        raise ValueError(
            f"source and target must have the same length, got {num_samples} and {len(target)}"
        )
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    # A single shared permutation keeps source/target rows aligned.
    indices = torch.randperm(num_samples, generator=generator)
    split_idx = int(train_ratio * num_samples)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]

    # No `del source, target` here: it would only unbind the local names and
    # cannot free the caller's tensors.
    return (
        source[train_indices],
        target[train_indices],
        source[val_indices],
        target[val_indices],
    )

In [8]:
from torch.utils.data import Dataset, DataLoader
class MappingDataset(Dataset):
    """Dataset of positionally paired (source, target) rows held on one device.

    Both tensors are moved to ``device`` once at construction time, so items
    returned by indexing already live on that device.
    """

    def __init__(self, source, target, device):
        """Store ``source`` and ``target`` after moving each to ``device``."""
        self.source = source.to(device)
        self.target = target.to(device)

    def __len__(self):
        """Return the number of rows in the source tensor."""
        return len(self.source)

    def __getitem__(self, idx):
        """Return the ``(source, target)`` pair found at position ``idx``."""
        source_item = self.source[idx]
        target_item = self.target[idx]
        return source_item, target_item

In [23]:
# Split the paired rows (default 80/20 inside split_data) and wrap each split
# in a DataLoader. MappingDataset moves the tensors to the given device.
s_train, t_train, s_val, t_val = split_data(intermediate_3b_18_all, intermediate_10_1b_all)
train_loader = DataLoader(MappingDataset(s_train, t_train, 'mps'), batch_size=128, shuffle=True)
# Validation order does not need to be random: a fixed order keeps the batch
# composition (and so the averaged per-batch loss) identical across epochs.
val_loader = DataLoader(MappingDataset(s_val, t_val, 'mps'), batch_size=128, shuffle=False)

In [24]:
# Drop the notebook-level names for the split tensors; the loaders hold their
# own device-side copies via MappingDataset.
del s_train, t_train, s_val, t_val

In [25]:
# Number of batches per epoch in the training and validation loaders.
print(len(train_loader))
print(len(val_loader))

2803
701


In [26]:
import torch
import torch.nn as nn

class SimpleEncoderDecoder(nn.Module):
    """Two-layer mapping network: Linear -> ReLU -> Linear.

    Args:
        input_size: Width of the input features.
        hidden_size: Width of the encoded representation.
        output_size: Width of the decoded output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode ``x``, apply ReLU, and decode to the output width."""
        return self.decoder(self.relu(self.encoder(x)))

In [36]:
import torch
import torch.nn as nn

class denseModel(nn.Module):
    """Three-layer MLP with GELU activations between the linear layers.

    Layout: Linear(input, hidden) -> GELU -> Linear(hidden, hidden) -> GELU
    -> Linear(hidden, output), all held in a single ``nn.Sequential`` named
    ``model``.

    Args:
        input_size: Width of the input features.
        hidden_size: Width of both hidden layers.
        output_size: Width of the output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Positional order inside the Sequential determines the state_dict
        # keys (model.0, model.2, model.4), so it must not change.
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        prediction = self.model(x)
        return prediction



In [13]:
class CustomLoss(nn.Module):
    """Summed squared error plus an L1 penalty on the model's decoder weights.

    The wrapped ``model`` must expose ``model.decoder.weight`` (as
    ``SimpleEncoderDecoder`` does); a model without a ``decoder`` attribute
    will raise ``AttributeError`` when the loss is evaluated.

    Args:
        model: Module whose ``decoder.weight`` is regularised.
        lambda_reg: Multiplier applied to the L1 norm of the decoder weights.
    """

    def __init__(self, model, lambda_reg=0.01):
        super().__init__()
        self.model = model
        self.lambda_reg = lambda_reg

    def forward(self, predictions, targets):
        """Return ``sum((predictions - targets)^2) + lambda_reg * ||W_dec||_1``."""
        squared_error = nn.functional.mse_loss(predictions, targets, reduction='sum')
        l1_penalty = self.lambda_reg * self.model.decoder.weight.norm(p=1)
        return squared_error + l1_penalty



In [45]:
# Training
import torch.optim as optim
from tqdm import tqdm
# Model: MLP from 3072-wide inputs to 2048-wide outputs with a 1024-wide
# hidden size, placed on the 'mps' device.
# (Shallower option kept for reference: SimpleEncoderDecoder(3072, 1024, 2048).)
model = denseModel(3072, 1024, 2048).to('mps')
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Objective: cosine embedding loss.
# (Other options kept for reference: CustomLoss(model, 0.01), nn.MSELoss().)
criterion = nn.CosineEmbeddingLoss()
epochs = 10

In [46]:
import copy

In [47]:
# Train for up to `epochs` epochs with early stopping on validation loss,
# keeping a deep copy of the model from the best validation epoch.
PATIENCE = 5  # stop once val loss has failed to improve for more than this many epochs
prev_val_loss = float('inf')  # lowest validation loss seen so far
patience_counter = 0          # consecutive epochs without improvement
train_losses = []
val_losses = []
best_model = None
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for inputs, targets in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        # The criterion is called with a third per-sample label tensor; all
        # ones asks it to pull each output towards its target. Allocate the
        # labels directly on the batch's device instead of creating them on
        # the CPU and copying to a hard-coded device every step.
        labels = torch.ones(len(inputs), device=inputs.device)
        loss = criterion(outputs, targets, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Average of the per-batch losses.
    train_loss = train_loss / (len(train_loader))
    train_losses.append(train_loss)
    print(f"Epoch {epoch}, Train Loss: {train_loss}")

    # ---- validation pass (no gradients needed) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model(inputs)
            labels = torch.ones(len(inputs), device=inputs.device)
            loss = criterion(outputs, targets, labels)
            val_loss += loss.item()
    val_loss = val_loss / (len(val_loader))
    val_losses.append(val_loss)
    print(f"Epoch {epoch}, Val Loss: {val_loss}")

    # ---- early-stopping bookkeeping ----
    if val_loss < prev_val_loss:
        prev_val_loss = val_loss
        patience_counter = 0
        # Snapshot the weights now; `model` keeps changing in later epochs.
        best_model = copy.deepcopy(model)
    else:
        patience_counter += 1
    if patience_counter > PATIENCE:
        print("Early stopping")
        break

100%|██████████| 2803/2803 [00:14<00:00, 187.69it/s]


Epoch 0, Train Loss: 0.16981781875602356
Epoch 0, Val Loss: 0.14800141904816647


100%|██████████| 2803/2803 [00:14<00:00, 192.64it/s]


Epoch 1, Train Loss: 0.13889075032077514
Epoch 1, Val Loss: 0.13974080401633504


100%|██████████| 2803/2803 [00:14<00:00, 190.55it/s]


Epoch 2, Train Loss: 0.1332511874344406
Epoch 2, Val Loss: 0.13483011496177583


100%|██████████| 2803/2803 [00:14<00:00, 191.18it/s]


Epoch 3, Train Loss: 0.1299610064151327
Epoch 3, Val Loss: 0.132697249838442


100%|██████████| 2803/2803 [00:14<00:00, 191.92it/s]


Epoch 4, Train Loss: 0.12763793444340032
Epoch 4, Val Loss: 0.1300234413478582


100%|██████████| 2803/2803 [00:14<00:00, 191.49it/s]


Epoch 5, Train Loss: 0.12590219139723194
Epoch 5, Val Loss: 0.12909492772715578


100%|██████████| 2803/2803 [00:14<00:00, 192.90it/s]


Epoch 6, Train Loss: 0.12450952401841235
Epoch 6, Val Loss: 0.12833930617902486


100%|██████████| 2803/2803 [00:14<00:00, 191.98it/s]


Epoch 7, Train Loss: 0.12338638802723846
Epoch 7, Val Loss: 0.1265983259180473


100%|██████████| 2803/2803 [00:14<00:00, 192.76it/s]


Epoch 8, Train Loss: 0.12244339924421303
Epoch 8, Val Loss: 0.1259966637944529


100%|██████████| 2803/2803 [00:14<00:00, 192.92it/s]


Epoch 9, Train Loss: 0.12164415936628666
Epoch 9, Val Loss: 0.12504676706057982


In [48]:
# Lowest validation loss recorded by the training loop (the one best_model corresponds to).
print(prev_val_loss)

0.12504676706057982


In [62]:
# Spot-check the best model on a single example from one training batch.
SAMPLE_IDX = 10  # position inside the batch to inspect

# Pull just one batch rather than looping over the loader and breaking.
inputs, targets = next(iter(train_loader))

# Inspection only: disable autograd so no graph is built and the printed
# tensors do not carry a grad_fn.
with torch.no_grad():
    outputs = best_model(inputs)

print(inputs[SAMPLE_IDX])
print(targets[SAMPLE_IDX])
# Direction agreement between prediction and target ...
print(cosine_similarity(outputs[SAMPLE_IDX], targets[SAMPLE_IDX], dim=-1, eps=1e-8))
# ... and their Euclidean distance. The model is trained on a cosine objective
# only, so magnitudes are not constrained to match.
print(torch.linalg.vector_norm(outputs[SAMPLE_IDX] - targets[SAMPLE_IDX]))

tensor([ 0.0276, -0.1827, -1.4195,  ...,  0.0373,  0.0581, -0.2269],
       device='mps:0')
tensor([-0.0760, -0.0870, -0.5498,  ...,  0.2267, -0.0715,  0.0630],
       device='mps:0')
tensor(0.9060, device='mps:0', grad_fn=<SumBackward1>)
tensor(1841.8755, device='mps:0', grad_fn=<LinalgVectorNormBackward0>)


In [63]:
# Persist the best model from the training loop.
# NOTE(review): this saves the whole module object rather than its state_dict,
# so loading presumably requires the model's class definition to be importable
# -- consider saving best_model.state_dict() instead; confirm with consumers.
# NOTE(review): the filename says "simple_encoder_decoder" but the training
# setup instantiates denseModel -- verify the name is still accurate.
torch.save(best_model, "simple_encoder_decoder_3_layers_cosine.pth")