In [1]:
import os
# Restrict this process to one physical GPU (index 3). This is set in its own
# cell, before the cell that imports torch, so it is in place before CUDA is used.
os.environ['CUDA_VISIBLE_DEVICES'] = "3"

In [2]:
import argparse
import random
import math
import json
import time 
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from rich import print as rprint
from typing import Dict, List
from datasets import load_dataset
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import seaborn as sns

# Seed torch's global RNG for reproducibility.
# NOTE(review): `random` and `numpy` are imported above but are not seeded here.
torch.manual_seed(42)

<torch._C.Generator at 0x7dc00c1829b0>

In [3]:
# Hugging Face model id; its last path component is also used to locate the
# cached dataset files (see `model_suffix`).
model_name = "meta-llama/Llama-2-7b-hf"
#model_name = "EleutherAI/pythia-1b"
#model_name = "/assets/models/meta-llama-3.1-8b"
# Run-mode switches:
#   load_model        - load the full causal LM (needed by the text-generation cell)
#   train             - run the probe training loop and the cells guarded by `if train:`
#   load_linear_model - together with `train`, warm-start the probe from a checkpoint
load_model = False
train = True
load_linear_model = False

In [4]:
# tokenizer
# NOTE(review): `device_map` is passed to the tokenizer loader here; presumably
# it has no effect for a tokenizer - confirm and drop if so.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    device_map="auto")
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# `model` stays None unless `load_model` is True; cells that call `model.*`
# therefore require load_model = True.
model = None
if load_model:
    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 device_map="auto")
    model.resize_token_embeddings(len(tokenizer))
    model.eval()
    # extract the embed_token layer and purge the model from memory
    # NOTE(review): only a detached clone of the embedding weights is taken
    # here; the model itself is not deleted in this cell.
    w = model.model.embed_tokens.weight.data.detach().clone()


In [5]:
# Hard-coded CUDA device used for the probe and for every batch moved with `.to(device)`.
device = torch.device("cuda")
print("Using device: ", device)

Using device:  cuda


#### Dataset with `[X,y]` with `X=h(t)` and `y=token(t)`

#### Create the dataset

#### Load the dataset

In [6]:
# Last path component of the model id (e.g. "Llama-2-7b-hf"); part of the file names below.
model_suffix = model_name.split("/")[-1]

# Pre-computed tensors: per the section heading, X holds hidden states h(t)
# and y the token at position t.
X_dataset = torch.load(f"data/X_dataset_{model_suffix}_top1000.pt")
y_dataset = torch.load(f"data/y_dataset_{model_suffix}_top1000.pt")
# presumably the number of distinct token classes in the "top1000" files - TODO confirm
total_classes = 1000

In [None]:
X_dataset_np = X_dataset.numpy()
y_dataset_np = y_dataset.numpy()

# Stratified 80/20 split on the labels, with a fixed seed for reproducibility.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# n_splits=1, so take the single (train, val) index pair directly instead of
# looping over a one-element generator.
train_index, val_index = next(splitter.split(X_dataset_np, y_dataset_np))

# Indexing with an index array already produces new tensors. Re-wrapping them
# with torch.tensor(...) only made a second copy and triggered PyTorch's
# "copy construct from a tensor" UserWarning, so the re-wrap is dropped.
X_train, X_val = X_dataset[train_index], X_dataset[val_index]
y_train, y_val = y_dataset[train_index], y_dataset[val_index]

In [8]:
# mean = X_train.mean()
# std = X_train.std()

# X_train = (X_train - mean) / std
# X_val = (X_val - mean) / std  # standardise features only; y_* hold class labels and must not be scaled

In [None]:
print("Train Dataset Size: ", X_train.shape)
print("Val Dataset Size: ", X_val.shape)
# Feature width of the stored hidden states; used as the probe's input dimension.
hidden_dim = X_dataset.shape[1]

In [10]:
# Wrap the split tensors as (features, label) datasets.
# NOTE(review): the validation split is stored under the name `test_dataset`.
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_val, y_val)



# Large batches: the contrastive loss and the validation statistics below are
# computed over all pairs of samples inside a batch.
train_dataloader = DataLoader(train_dataset, batch_size=4096, shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=4096)

### Linear probe for current token classification

- Derive soft labels from inverse transformations of token embeddings using the pseudoinverse of the embedding layer matrix
- Train linear model without bias using MSE loss

#### Define model

In [None]:
# Width factor: the probe trained below has total_classes * multiplier outputs.
multiplier = 3


class LinearProbe(nn.Module):
    """Single bias-free linear layer mapping a hidden state to `n_classes` outputs."""

    def __init__(self, hdim=hidden_dim, n_classes=len(tokenizer)):
        super().__init__()
        self.fc1 = nn.Linear(in_features=hdim, out_features=n_classes, bias=False)

    def forward(self, X):
        """Return the linear projection of `X`."""
        return self.fc1(X)


def cosine_orthogonality_loss(Y, labels, lambda_intra=1.0, lambda_inter=1.0, lambda_reg_intra=1.0, lambda_reg_inter=1.0):
    """Contrastive loss over all sample pairs in a batch.

    Pairs with equal labels (including each sample paired with itself) are
    pushed towards cosine similarity 1; pairs with different labels are pushed
    towards cosine similarity 0. The variance of the absolute raw inner
    products within each group is added as a regulariser.

    Args:
        Y: 2-D tensor of outputs, one row per sample.
        labels: 1-D tensor of class labels, one per row of `Y`.
        lambda_intra: weight of the same-class cosine term.
        lambda_inter: weight of the different-class cosine term.
        lambda_reg_intra: weight of the same-class inner-product variance.
        lambda_reg_inter: weight of the different-class inner-product variance.

    Returns:
        Scalar tensor holding the weighted sum of the four terms.
    """
    labels_expanded = labels.unsqueeze(1)
    same_class_mask = labels_expanded == labels_expanded.T
    diff_class_mask = ~same_class_mask

    Y_normalized = F.normalize(Y, p=2, dim=1, eps=1e-6)

    cosine_similarity = torch.matmul(Y_normalized, Y_normalized.T)
    inner_product = torch.matmul(Y, Y.T)

    # Intra-class loss: encourage cosine similarity close to 1
    intra_class_sims = cosine_similarity[same_class_mask]
    intra_class_loss = torch.mean(1 - intra_class_sims)

    # Inter-class loss: encourage orthogonality (cosine similarity close to 0)
    inter_class_sims = cosine_similarity[diff_class_mask]
    inter_class_loss = torch.mean(inter_class_sims ** 2)

    intra_class_ip = torch.abs(inner_product[same_class_mask])
    inter_class_ip = torch.abs(inner_product[diff_class_mask])

    intra_class_variance = torch.var(intra_class_ip)
    inter_class_variance = torch.var(inter_class_ip)

    # Each variance uses its own weight. Previously `lambda_reg_inter` was
    # applied to both terms and `lambda_reg_intra` was silently ignored.
    regularization_term = lambda_reg_intra * intra_class_variance + \
        lambda_reg_inter * inter_class_variance

    # Final loss with proper weighting
    loss = lambda_intra * intra_class_loss + lambda_inter * \
        inter_class_loss + regularization_term

    return loss


# Probe with total_classes * multiplier outputs, moved to the training device.
linear_nn = LinearProbe(hidden_dim, total_classes*multiplier)
linear_nn.to(device)

#### Train the model

In [12]:
num_epochs = 500
# StepLR (configured below) halves the learning rate every 100 scheduler steps;
# the training loop steps the scheduler once per epoch.
learning_rate = 1e-4

optimizer = torch.optim.AdamW(linear_nn.parameters(), lr=learning_rate)

steplr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)    

In [13]:
def validate():
    """Print and return pairwise inner-product statistics of `linear_nn` on `val_dataloader`.

    For every validation batch, all pairs of probe outputs are split into
    same-label pairs (a sample paired with itself counts as same-label) and
    different-label pairs. Cosine similarities and raw inner products are
    accumulated over all batches.

    The printed "Validation Accuracy" is the fraction of batches whose mean
    same-label inner product exceeds the mean different-label inner product.

    Sums are accumulated in float64 and the E[x^2] - E[x]^2 variances are
    clamped at zero, so rounding error cannot produce a negative variance
    (and hence a NaN standard deviation).

    Returns:
        Tuple of Python floats:
        (mean_inter_class_ip, mean_intra_class_ip,
         var_inter_class_ip, var_intra_class_ip).
    """
    linear_nn.eval()
    total_inter_class_ip = 0.0
    total_intra_class_ip = 0.0
    total_inter_class_ip_sqr = 0.0
    total_intra_class_ip_sqr = 0.0
    total_inter_class_sim = 0.0
    total_intra_class_sim = 0.0
    num_inter_pairs = 0
    num_intra_pairs = 0
    correct = 0

    with torch.no_grad():
        for batch in val_dataloader:
            X = batch[0].to(device)
            y = batch[1].to(device)
            logits = linear_nn(X)
            logits_normalized = F.normalize(logits, p=2, dim=1, eps=1e-6)

            inner_product = torch.matmul(logits, logits.T)
            cosine_similarity = torch.matmul(
                logits_normalized, logits_normalized.T)

            y_batch = y.unsqueeze(1)
            same_class_mask = (y_batch == y_batch.T)
            diff_class_mask = ~same_class_mask

            intra_class_sim = cosine_similarity[same_class_mask]
            intra_class_ip = inner_product[same_class_mask]

            inter_class_sim = cosine_similarity[diff_class_mask]
            inter_class_ip = inner_product[diff_class_mask]

            # Per-batch check: is the mean same-label inner product larger
            # than the mean different-label inner product?
            avg_intra_class_ip = intra_class_ip.mean().item() if intra_class_ip.numel() > 0 else 0.0
            avg_inter_class_ip = inter_class_ip.mean().item() if inter_class_ip.numel() > 0 else 0.0
            correct += int(avg_intra_class_ip > avg_inter_class_ip)

            # Accumulate in float64 to limit rounding error across many large batches.
            total_inter_class_ip += inter_class_ip.sum(dtype=torch.float64).item()
            total_inter_class_ip_sqr += inter_class_ip.pow(2).sum(dtype=torch.float64).item()

            total_intra_class_ip += intra_class_ip.sum(dtype=torch.float64).item()
            total_intra_class_ip_sqr += intra_class_ip.pow(2).sum(dtype=torch.float64).item()

            total_inter_class_sim += inter_class_sim.sum(dtype=torch.float64).item()
            total_intra_class_sim += intra_class_sim.sum(dtype=torch.float64).item()

            num_intra_pairs += intra_class_ip.numel()
            num_inter_pairs += inter_class_ip.numel()

    # Compute means safely (avoid division by zero)
    num_inter_pairs = max(num_inter_pairs, 1)
    num_intra_pairs = max(num_intra_pairs, 1)

    mean_inter_class_sim = total_inter_class_sim / num_inter_pairs
    mean_intra_class_sim = total_intra_class_sim / num_intra_pairs

    # Clamp at zero: the subtraction can go slightly negative through rounding.
    mean_intra_class_ip = total_intra_class_ip / num_intra_pairs
    var_intra_class_ip = max(
        total_intra_class_ip_sqr / num_intra_pairs - mean_intra_class_ip**2, 0.0)

    mean_inter_class_ip = total_inter_class_ip / num_inter_pairs
    var_inter_class_ip = max(
        total_inter_class_ip_sqr / num_inter_pairs - mean_inter_class_ip**2, 0.0)

    # Fraction of batches that passed the per-batch check above.
    accuracy = correct / max(len(val_dataloader), 1)

    print("Validation Accuracy: ", accuracy)
    print(f"Inter-class similarity: {mean_inter_class_sim}")
    print(f"Intra-class similarity: {mean_intra_class_sim}")
    print(
        f"Inter-class inner product std dev: {math.sqrt(var_inter_class_ip)}")
    print(
        f"Intra-class inner product std dev: {math.sqrt(var_intra_class_ip)}")

    return mean_inter_class_ip, mean_intra_class_ip, var_inter_class_ip, var_intra_class_ip

In [14]:
# Optional warm start: load a previously saved probe and report its validation statistics.
# NOTE(review): this file name differs from the one the save cell writes; confirm
# the checkpoint matches the current probe's output width.
if train and load_linear_model:
    linear_nn.load_state_dict(torch.load(
        f"saved_models/linear_probe_contrastive_c{total_classes}.pth"))
    print("Loaded Linear Model")
    validate()

In [None]:
# Per-epoch mean training loss, and the tuples returned by validate() every 5th epoch.
train_losses = []
val_accs = []

if train:
    for epoch in range(num_epochs):

        # Training epoch
        linear_nn.train()
        epoch_train_loss = 0.0

        for i, batch in enumerate(train_dataloader):
            X = batch[0].to(device)
            y = batch[1].to(device)
            outputs = linear_nn(X)
            # Loss weights for this run are passed explicitly here.
            loss = cosine_orthogonality_loss(
                outputs, y, lambda_intra=400, lambda_inter=400, lambda_reg_inter=0.1, lambda_reg_intra=0.6)  # Compute loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_train_loss += loss.item()
            
        # One scheduler step per epoch.
        steplr.step()
        avg_train_loss = epoch_train_loss / len(train_dataloader)
        train_losses.append(avg_train_loss)
        # Validate on epochs 0, 5, 10, ... (0-based).
        if epoch % 5 == 0:
            print(f"Epoch: {epoch} with train loss:", avg_train_loss)
            val_accs.append(validate())

In [None]:
if train:
    # val_accs holds one (mean_inter, mean_intra, var_inter, var_intra) tuple per validation run.
    mean_inter_class_ip, mean_intra_class_ip, var_inter_class_ip, var_intra_class_ip = zip(*
                                                                                           val_accs)

    # x positions, 1-based to match the training-loss axis. Validation ran on
    # 0-based epochs 0, 5, 10, ..., i.e. 1-based epochs 1, 6, 11, ...
    # Previously the validation curves were plotted at 1..num_epochs//5, which
    # squeezed them into the first fifth of the axis.
    train_epochs = range(1, len(train_losses) + 1)
    val_epochs = range(1, 5 * len(val_accs) + 1, 5)

    plt.figure(figsize=(10, 6))
    plt.plot(train_epochs, train_losses, label='Training Loss')
    plt.plot(val_epochs, mean_inter_class_ip, label='Inter-class IP')
    plt.plot(val_epochs, mean_intra_class_ip, label='Intra-class IP')
    plt.plot(val_epochs, var_inter_class_ip,
             label='Inter-class IP Variance')
    plt.plot(val_epochs, var_intra_class_ip,
             label='Intra-class IP Variance')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()

    # Last two recorded values of each series, as a quick convergence check.
    print("Mean Inter-class IP: ",
          mean_inter_class_ip[-1], mean_inter_class_ip[-2])
    print("Mean Intra-class IP: ",
          mean_intra_class_ip[-1], mean_intra_class_ip[-2])
    print("Variance Inter-class IP: ",
          var_inter_class_ip[-1], var_inter_class_ip[-2])
    print("Variance Intra-class IP: ",
          var_intra_class_ip[-1], var_intra_class_ip[-2])
    print("Train Loss: ", train_losses[-1], train_losses[-2])

Combine class means with the original embeddings and compute the inner product similarity matrix that should transform hidden states to one-hot encodings

In [None]:
if True:
    linear_nn.eval()

    unique_classes = torch.unique(y_dataset)
    # One row per tokenizer id; rows for ids that never occur in y_dataset stay zero.
    class_means = torch.zeros(len(tokenizer), multiplier*total_classes)
    # Mean pairwise cosine similarity inside each class (plotted in the next cell).
    cosine_similarity_scores = []

    for class_idx in tqdm(unique_classes):
        # Find all samples of the same class
        class_mask = y_dataset == class_idx
        X_class = X_dataset[class_mask]

        with torch.no_grad():
            logits = linear_nn(X_class.to(device))

        # Normalize logits
        logits = F.normalize(logits, p=2, dim=1, eps=1e-6)

        # Compute cosine similarity matrix
        cosine_similarity = torch.matmul(logits, logits.T)
        # get average cosine similarity
        cosine_similarity_scores.append(cosine_similarity.mean().item())

        # Class prototype: mean of the L2-normalised probe outputs of this class.
        class_means[class_idx] = logits.mean(dim=0)

    # Compose prototypes with the probe weights, giving a single matrix that
    # maps a hidden state to one score per tokenizer id.
    weights = linear_nn.fc1.weight.data
    final_weights = torch.matmul(class_means.to(weights.device), weights)
    linear_nn_final = LinearProbe(hidden_dim, len(tokenizer)).to(device)
    with torch.no_grad():
        linear_nn_final.fc1.weight.copy_(final_weights)

In [None]:
# Plot histogram of cosine similarity scores
# (one value per class: the mean pairwise cosine similarity within that class)
plt.figure(figsize=(10, 6))
plt.hist(cosine_similarity_scores, bins=50)
plt.xlabel('Cosine Similarity')
plt.ylabel('Frequency')
plt.title('Cosine Similarity Scores')
plt.show()

In [None]:

correct = 0
# Let's calculate the accuracy on the validation set
# (argmax over the composed probe's outputs is compared with the label)
for i, batch in enumerate(val_dataloader):
    with torch.no_grad():
        X = batch[0].to(device)
        y = batch[1].to(device)
        outputs = linear_nn_final(X)
        labels = outputs.argmax(dim=1)

        correct += (labels == y).sum()

print("Validation Accuracy: ", correct.item()/X_val.shape[0])

In [None]:
# Summarise the per-sample largest and smallest outputs of the composed probe on the validation set.
if train:
    linear_nn.eval()

    # NOTE(review): `unique_classes` is reassigned here but not used in this cell.
    unique_classes = torch.unique(y_val)

    max_values = []
    min_values = []

    for i, batch in enumerate(val_dataloader):
        X = batch[0].to(device)
        y = batch[1].to(device)
        with torch.no_grad():
            logits = linear_nn_final(X)
        max_values.append(logits.max(dim=1).values)
        min_values.append(logits.min(dim=1).values)
        
    max_values = torch.cat(max_values)
    min_values = torch.cat(min_values)

    print(f"Mean Max: {max_values.mean().item()} | Std Dev: {max_values.std().item()}")
    print(f"Mean Min: {min_values.mean().item()} | Std Dev: {min_values.std().item()}")    
    
#     # # Divide linear model weights by mean_max
#     # with torch.no_grad():
#     #     linear_nn.fc1.weight.div_(mean_max)


In [None]:
if train:
    # Create the checkpoint directory first, so torch.save cannot fail on a
    # missing folder after a long training run.
    os.makedirs("saved_models", exist_ok=True)
    torch.save(linear_nn.state_dict(), f'saved_models/linear_probe_contrastive_c{total_classes}_bs4096_400_400_0.1.pth')


# Deliberate stop so that "Run All" halts here. The cells below are exploratory
# and expect a separately loaded checkpoint. The exception type (RuntimeError)
# is the same as the bare `raise` it replaces, but the message now says why.
raise RuntimeError("Intentional stop: training cells finished; run the cells below manually.")

### Load model

##### Some qualitative checks on the linear probe

In [None]:

# Rebuild a probe with `total_classes` outputs and load a saved checkpoint into it.
# NOTE(review): the training section builds the probe with total_classes*multiplier
# outputs and saves under a different file name; confirm this checkpoint matches
# the architecture created here.
linear_nn = LinearProbe(hidden_dim, total_classes).to(device)
model_path = f'saved_models/linear_probe_contrastive_c{total_classes}.pth'
linear_nn.load_state_dict(torch.load(model_path))

In [None]:
# Per-sample largest and smallest outputs of the loaded probe on the validation set.
linear_nn.eval()

max_values = []
min_values = []

for i, batch in enumerate(val_dataloader):
    X = batch[0].to(device)
    y = batch[1].to(device)
    with torch.no_grad():
        logits = linear_nn(X)
    max_values.append(logits.max(dim=1).values)
    min_values.append(logits.min(dim=1).values)
    
# One entry per validation sample, in dataloader order (the val loader is not shuffled).
max_values = torch.cat(max_values)
min_values = torch.cat(min_values)

print(f"Mean Max: {max_values.mean().item()} | Std Dev: {max_values.std().item()}")
print(f"Mean Min: {min_values.mean().item()} | Std Dev: {min_values.std().item()}")

In [None]:
# find three sigma outliers
# A sample is flagged when its max output is more than 3 std devs above the
# mean max, or its min output is more than 3 std devs below the mean min.
outliers = (max_values > max_values.mean() + 3 * max_values.std()) | (min_values < min_values.mean() - 3 * min_values.std())
outliers = outliers.cpu()
# Number of distinct labels among the flagged validation samples.
len(torch.unique(y_val[outliers]))

Get accuracy on train dataset

In [None]:
linear_nn.eval()

correct = 0
# Let's calculate the accuracy on the train set
# (argmax of the probe output is compared with the label)
for i, batch in enumerate(train_dataloader):
    with torch.no_grad():
        X = batch[0].to(device)
        y = batch[1].to(device)

        outputs = linear_nn(X)
        labels = outputs.argmax(dim=1)

        correct += (labels == y).sum()

print("Train Accuracy: ", correct.item()/X_train.shape[0])

Get accuracy on val dataset

In [None]:

correct = 0
# Let's calculate the accuracy on the validation set
for i, batch in enumerate(val_dataloader):
    with torch.no_grad():
        X = batch[0].to(device)
        y = batch[1].to(device)

        outputs = linear_nn(X)
        labels = outputs.argmax(dim=1)

        correct += (labels == y).sum()

print("Validation Accuracy: ", correct.item()/X_val.shape[0])

In [None]:
# Collect the two largest probe outputs for every validation sample.
top2_values = []
for i, batch in enumerate(val_dataloader):
    with torch.no_grad():
        X = batch[0].to(device)
        y = batch[1].to(device)

        outputs = linear_nn(X)
        # Store the two largest outputs per row (their difference is taken after the loop)
        topk_values, _ = outputs.topk(2)
        top2_values.append(topk_values[:, :2])

    
top2_values = torch.cat(top2_values).cpu()

# Margin between the largest and the second-largest output of each sample.
diff_values = top2_values[:, 0] - top2_values[:, 1]

In [None]:
# Statistics of the largest output per validation sample.
max_values = top2_values[:, 0]
mean = max_values.mean()
std_dev = max_values.std()
# NOTE(review): `zscores` is computed here but not used in this cell.
zscores = (max_values - mean) / std_dev

print("Mean: ", mean.item())
print("Standard Deviation: ", std_dev.item())


In [None]:

mean = diff_values.mean()
std_dev = diff_values.std()
zscores = (diff_values - mean) / std_dev

# Remove upper-tail outliers (z-score >= 3) into a NEW tensor. Overwriting
# `diff_values` made this cell non-idempotent: every re-run filtered the
# already-filtered data again.
diff_values_inliers = diff_values[zscores < 3]

# plot histogram of diff values
plt.figure(figsize=(10, 6))
plt.hist(diff_values_inliers.numpy(), bins=1000)
plt.xlabel('Difference between top 2 logits')
plt.ylabel('Frequency')
plt.title('Difference between top 2 logits Distribution')
plt.show()

Some manual qualitative checks

In [None]:
# Inspect the distribution of probe outputs for one hand-picked validation sample.
idx = 46
X = X_val[idx].unsqueeze(0).to(device)
with torch.no_grad():
    logits = linear_nn(X)


# plots
plt.figure(figsize=(10, 6))
plt.hist(logits.cpu().numpy().flatten(), bins=100)
plt.xlabel('Logits')
plt.ylabel('Frequency')
plt.title('Logits Distribution')
plt.show()

### Build watermark matrix

In [None]:


# Generator re-seeded by get_partition() on every call.
rng = torch.Generator(device=device)

vocab_size = len(tokenizer)
# Fraction of the vocabulary placed on the green list.
gamma = 0.50
# Multiplier applied to the last token id to form the PRF key.
hash_key = 15485863


def prf_lookup(input_ids):
    """Return the PRF key: `hash_key` times the last id in `input_ids`."""
    # Summing the one-element tail slice yields the last id as a Python number.
    last_token = input_ids[-1:].sum().item()
    return hash_key * last_token


def get_partition(input_ids):
    """Split the vocabulary into green and red token ids for the given context.

    The shared generator `rng` is re-seeded from the PRF key of `input_ids`,
    a permutation of the vocabulary is drawn on `device`, and the first
    int(vocab_size * gamma) entries form the green list.

    Returns:
        (greenlist_ids, redlist_ids): two index tensors on the CPU.
    """
    seed = prf_lookup(input_ids) % (2**64 - 1)
    rng.manual_seed(seed)

    permutation = torch.randperm(vocab_size, device=device, generator=rng)
    cut = int(vocab_size * gamma)
    return permutation[:cut].to("cpu"), permutation[cut:].to("cpu")

# values x keys
# Column i holds the partition seeded by token id i: green rows are 2.0, red rows are 0.
watermark_matrix = torch.zeros(len(tokenizer), len(tokenizer)).to(device)


for i in tqdm(range(len(tokenizer))):
    greenlist_ids, redlist_ids = get_partition(torch.tensor([i]))
    watermark_matrix[greenlist_ids, i] = 2.0
    # The matrix starts as zeros, so this assignment only makes the red value explicit.
    watermark_matrix[redlist_ids, i] = 0

In [None]:
# Compare the probe output of one validation sample with its one-hot target.
idx = 9
X = X_val[idx].unsqueeze(0).to(device)
y = y_val[idx].unsqueeze(0).to(device)
# Element-wise (unreduced) L1 loss. `reduction='none'` replaces the deprecated
# `reduce=False` argument and behaves the same.
l1_criterion = nn.L1Loss(reduction='none')
with torch.no_grad():
    logits = linear_nn(X).squeeze(0)
    # Watermark scores implied by the probe output and by the one-hot target.
    ans = torch.matmul(watermark_matrix, logits)
    one_hot = torch.zeros_like(logits)
    one_hot[y] = 1
    gold = torch.matmul(watermark_matrix, one_hot)
    # Per-coordinate absolute difference between the probe output and the one-hot target
    diff = l1_criterion(logits, one_hot)
    # Coordinate where the probe output deviates most from the target.
    max_diff_index = torch.argmax(diff)
    print(logits[max_diff_index])
    print(one_hot[max_diff_index])
    print(logits.topk(5))
    print(one_hot.topk(5))

In [None]:
    
# `model` is None unless load_model was True when the loading cell ran. Fail
# with a clear message instead of an AttributeError on `model.device`.
if model is None:
    raise RuntimeError(
        "The language model is not loaded; set load_model = True and re-run the model-loading cell.")

test_text = "The cat sat on the "
inputs = tokenizer(test_text, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        **inputs, return_dict_in_generate=True,  max_length=10)



In [None]:
# Decode the first generated sequence to text, dropping special tokens.
seq = output.sequences[0]
tokenizer.decode(seq, skip_special_tokens=True)

Also do the analysis using the watermark matrix (this is what I mean when I say I want the logits to be close).

In [None]:
watermark_matrix = torch.load("data/watermark_matrix_gamma0.50_simple1_key15485863.pt")

# dummy logits where everything is a small positive value
# probe_outputs = torch.ones((10, 50277), dtype=torch.float32)/10
# NOTE(review): `outputs` is not defined in this cell; it is whatever an earlier
# cell left in the kernel - confirm which batch is intended.
probe_outputs = outputs
# wm_predicted[b, j] = sum_i probe_outputs[b, i] * watermark_matrix[j, i]
wm_predicted = einops.einsum(probe_outputs.cpu(), watermark_matrix.T, "b i, i j -> b j")

print("Watermark Delta Prediction: ", wm_predicted.shape)

In [78]:
# Green/red split seeded by the hard-coded token id 187.
# NOTE(review): get_partition must already be defined in the kernel when this runs.
greenlist_ids, redlist_ids = get_partition(torch.tensor([187]))

In [None]:
# Mean predicted watermark score of sample 0 over the green and red token sets.
print("Mean Green List: ", wm_predicted[0, greenlist_ids].mean().item())
print("Mean Red List: ", wm_predicted[0, redlist_ids].mean().item())

In [None]:
# Distribution of sample 0's predicted watermark scores on the green list.
plt.figure(figsize=(4, 3))
sns.histplot(wm_predicted[0, greenlist_ids].cpu(), bins=30, kde=True, color="green")
# sns.histplot(wm_predicted[0, redlist_ids].cpu(), bins=30, kde=True, color="red")
plt.show()

##### Magnitude of the Logits :- Clean Analysis of the Logits

##### Check how biased these logits are using this loss: f2 + f3 loss

In [None]:
# Distribution of all predicted watermark scores for sample 7.
plt.figure(figsize=(4, 3))
sns.histplot(wm_predicted[7].cpu(), bins=30, kde=True)
plt.show()

In [62]:
# Green/red split seeded by the label of sample 6.
# NOTE(review): `y` is whatever batch of labels an earlier cell left in the kernel - confirm.
greenlist_ids, redlist_ids = get_partition(torch.tensor([y[6]]))

In [None]:
# Overlay sample 6's predicted watermark scores on its green and red lists.
plt.figure(figsize=(4, 3))
sns.histplot(wm_predicted[6, greenlist_ids].cpu(), bins=30, kde=True, color="green")
sns.histplot(wm_predicted[6, redlist_ids].cpu(), bins=30, kde=True, color="red")
plt.show()

##### Check the MSE loss on the watermark logits 

In [15]:
# Load the saved watermark matrix; it is not moved to `device` here (the move is commented out).
watermark_matrix = torch.load("data/watermark_matrix_gamma0.50_simple1_key15485863.pt")
# watermark_matrix = watermark_matrix.to(device)

In [None]:
mse_loss = nn.MSELoss()
linear_nn.eval()

# Evaluate at most this many validation batches. The following cell divides
# `total_loss` by 100, so exactly 100 batches must be accumulated.
max_batches = 100

total_loss = 0
with torch.no_grad():
    # enumerate(tqdm(...)) lets tqdm see the dataloader length and show a total.
    for i, batch in enumerate(tqdm(val_dataloader)):
        # Stop BEFORE accumulating. The previous check ran after the
        # accumulation, so 101 batches were summed but averaged over 100.
        if i >= max_batches:
            break

        X = batch[0]
        y = batch[1]

        outputs = linear_nn(X.to(device))
        outputs = torch.sigmoid(outputs).to("cpu")

        # wm_predicted[b, j] = sum_i outputs[b, i] * watermark_matrix[j, i]
        wm_predicted = einops.einsum(outputs, watermark_matrix.T, "b i, i j -> b j")
        # Ground truth: the watermark column selected by each label.
        wm_gt = watermark_matrix[:, y].T

        loss = mse_loss(wm_predicted, wm_gt)
        total_loss += loss.item()

        #? to check if matrix mult is okay
        # (the one-hot tensor is large, so it is only built when this check is enabled)
        # y_onehot = F.one_hot(y, num_classes=len(tokenizer)).to(torch.float32)
        # wm_generated = einops.einsum(y_onehot, watermark_matrix.T, "b i, i j -> b j")
        # wm_gt = watermark_matrix[:, y].T
        # diff = wm_generated - wm_gt
        # print(diff.sum())

# len(val_dataloader) * val_dataloader.batch_size

In [None]:
# Mean per-batch loss.
# NOTE(review): the divisor is hard-coded; confirm it equals the number of
# batches actually accumulated into `total_loss`.
avg_val_loss = total_loss/100
print(avg_val_loss)

### Build the watermark matrix

Generate the watermark matrix and save it

In [None]:
# The generator created in the next cell lives on this device.
print("Device: ", device)

In [77]:
# Column i of the watermark matrix holds the green/red split seeded by token id i.
# NOTE(review): this cell and the two functions below repeat definitions that
# already appear in the earlier "Build watermark matrix" section.
rng = torch.Generator(device=device)

vocab_size = len(tokenizer)
# Fraction of the vocabulary placed on the green list.
gamma = 0.50
# Multiplier applied to the last token id to form the PRF key.
hash_key = 15485863

def prf_lookup(input_ids):
    """Return the PRF key: `hash_key` multiplied by the last id in `input_ids`."""
    return hash_key * input_ids[-1:].sum().item()

def get_partition(input_ids):
    """Split the vocabulary into green and red token ids for `input_ids`.

    Re-seeds the shared generator `rng` from the PRF key, draws a permutation
    of the vocabulary on `device`, and returns (greenlist_ids, redlist_ids) as
    CPU tensors. The green list is the first int(vocab_size * gamma) entries.
    """
    prf_key = prf_lookup(input_ids)
    rng.manual_seed(prf_key % (2**64 - 1))

    greenlist_size = int(vocab_size * gamma)
    vocab_permutation = torch.randperm(vocab_size, device=device, generator=rng)
    greenlist_ids = vocab_permutation[:greenlist_size].to("cpu")
    redlist_ids = vocab_permutation[greenlist_size:].to("cpu")
    return greenlist_ids, redlist_ids

In [None]:
# Square matrix over the vocabulary, built on the CPU.
# Column i: +1.0 on the green list seeded by token id i, -1.0 on its red list.
watermark_matrix = torch.zeros(len(tokenizer), len(tokenizer))

for i in tqdm(range(len(tokenizer))):
    greenlist_ids, redlist_ids = get_partition(torch.tensor([i]))
    watermark_matrix[greenlist_ids, i] = 1.0
    watermark_matrix[redlist_ids, i] = -1.0

In [9]:
# Persist the matrix; the file name records gamma (0.50) and the hash key.
torch.save(watermark_matrix, "data/watermark_matrix_gamma0.50_simple1_key15485863.pt")

Load the watermark matrix

In [12]:
# Reload the matrix written by the save cell above.
watermark_matrix = torch.load("data/watermark_matrix_gamma0.50_simple1_key15485863.pt")

### Two Hidden Layer NN with second matrix frozen as the watermark matrix

Things tried

- Regression loss between the multi-hot encoded vectors and logits (does not work well).
- Formulate the problem as multi-label classification problem. `saved as config 2`
- Loss as the cosine similarity between the outputs of the model and watermark logits (make it +1 and -1).

TODO:

- [ ] Capped ReLU.
- [ ] Sparsity in the intermediate representation.

In [13]:
class WatermarkNetwork(nn.Module):
    """Two-layer network whose second layer is frozen to the watermark matrix.

    `fc1` (trainable) maps a hidden state to a vocabulary-sized vector, which
    passes through a ReLU. `fc2` (frozen, zero bias) multiplies that vector by
    `watermark_matrix`.

    Args:
        hdim: input (hidden-state) dimension.
        vocab_size: width of the intermediate and output vectors.
        watermark_matrix: (vocab_size, vocab_size) tensor used as `fc2.weight`.

    Raises:
        ValueError: if `watermark_matrix` is missing or has the wrong shape.
    """

    def __init__(self, hdim=2048, vocab_size=len(tokenizer), watermark_matrix=None):
        super().__init__()
        # Fail early: nn.Parameter(None) would silently create an empty weight
        # and only break later, inside forward().
        if watermark_matrix is None:
            raise ValueError("watermark_matrix is required")
        if tuple(watermark_matrix.shape) != (vocab_size, vocab_size):
            raise ValueError(
                f"watermark_matrix must have shape ({vocab_size}, {vocab_size}), "
                f"got {tuple(watermark_matrix.shape)}")

        self.fc1 = nn.Linear(hdim, vocab_size)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(vocab_size, vocab_size)

        # set the watermark matrix as the second layer
        self.fc2.weight = nn.Parameter(watermark_matrix, requires_grad=False)
        self.fc2.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=False)

    def forward(self, x):
        """Return fc2(ReLU(fc1(x)))."""
        x = self.act1(self.fc1(x))
        return self.fc2(x)

In [None]:
# Pass the dataset's feature width explicitly. The class default (hdim=2048)
# only matches models whose hidden size is 2048, while the training batches
# have `hidden_dim` features.
wmNet = WatermarkNetwork(hdim=hidden_dim, watermark_matrix=watermark_matrix)
wmNet.to(device)

In [15]:
# Will try training the neural network with a regression loss on the watermark outputs
# NOTE: this cell rebinds `learning_rate`, `train_dataloader`, `val_dataloader`
# and `optimizer`, replacing the probe-training ones defined earlier.
batch_size = 64
learning_rate = 1e-3

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The validation split is named `test_dataset` in this notebook; `val_dataset`
# is never defined and raised a NameError on a fresh kernel.
val_dataloader = DataLoader(test_dataset, batch_size=32)

# freeze the second layer
wmNet.fc2.weight.requires_grad = False
wmNet.fc2.bias.requires_grad = False

# criterion = nn.MSELoss()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(wmNet.parameters(), lr=learning_rate)

In [None]:
# Sanity check: list every parameter and whether it will receive gradients.
for name, param in wmNet.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

In [None]:
# Count the total number of parameters
# (only those with requires_grad=True are included in the sum)
total_params = sum(p.numel() for p in wmNet.parameters() if p.requires_grad)
print("Total number of trainable parameters:", total_params)

In [None]:
num_epochs = 1

for epoch in range(num_epochs):
    # enumerate(tqdm(...)) lets tqdm see the dataloader length and show a total.
    for i, batch in enumerate(tqdm(train_dataloader)):

        X = batch[0].to(device)
        y = batch[1].to(device)

        output = wmNet(X)
        # Target for each sample: the watermark column selected by its label,
        # arranged as (batch, vocab).
        targets = watermark_matrix[:, y.to("cpu")]
        targets = targets.permute(1, 0).to(device)

        # BCEWithLogitsLoss expects targets in [0, 1]. The saved watermark
        # matrix uses +1 / -1, so map "green" (> 0) to 1 and everything else
        # to 0. Other criteria (e.g. MSE) keep the raw values.
        if isinstance(criterion, nn.BCEWithLogitsLoss):
            targets = (targets > 0).to(output.dtype)

        loss = criterion(output, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 1000 == 0:
            print(f'Step [{i + 1}/{len(train_dataloader)}], Loss: {loss.item():.4f}')

In [44]:
# Save all wmNet parameters (including the frozen watermark layer) as "config 2".
torch.save(wmNet.state_dict(), 'saved_models/wm_config2.pth')

### Generate some text and (1) 