In [1]:
# Environment setup: pin CUDA_VISIBLE_DEVICES when running in a notebook; use a headless matplotlib backend otherwise.
def is_notebook() -> bool:
    """Return True only when running inside a Jupyter kernel (notebook or qtconsole).

    The check looks at the class name of the active IPython shell. A terminal IPython
    session, any other shell type, or a plain Python interpreter (where ``get_ipython``
    does not exist) all yield False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # `get_ipython` is only defined when IPython is driving the interpreter.
        return False
    return shell_name == 'ZMQInteractiveShell'

import os
import matplotlib

if is_notebook():
    # Interactive kernel: pin this process to GPU 0 (switch to "1" for the second GPU).
    # For CUDA debugging, CUDA_LAUNCH_BLOCKING="1" / TORCH_USE_CUDA_DSA="1" can also be set here.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
    # Script / headless run: non-interactive backend so figures can be written without a display.
    matplotlib.use('Agg')

In [None]:
# TODO: make gradient accumulation work for losses computed over the whole batch: first compute
# logits for the full virtual batch without grad, then iterate over mini-batches, replacing the
# corresponding logits with grad-tracked ones.

In [2]:
import os

# Work from the repository root: the project-local imports below (losses/, models/, utils/) and the
# relative experiment directory (output/...) presumably resolve against the current working
# directory. Set DIVERSE_GEN_ROOT to run from a different checkout instead of editing this path.
REPO_ROOT = os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen")
os.chdir(REPO_ROOT)

In [3]:
# Basic setup: the model has already been pretrained on the full dataset; here we fine-tune it
# on the source and target data using the standard losses.

In [4]:
import torch
from torch.utils.data import random_split
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from transformers import AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

from losses.divdis import DivDisLoss 
from losses.divdis import DivDisLoss
from losses.ace import ACELoss
from losses.conf import ConfLoss
from losses.dbat import DBatLoss
from losses.smooth_top_loss import SmoothTopLoss
from losses.loss_types import LossType

from models.backbone import MultiHeadBackbone
from utils.utils import batch_size, to_device


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass
class Config:
    """Hyperparameters for fine-tuning the measurement-prediction model with a diversification loss.

    Every attribute is type-annotated so it is a real dataclass field. Unannotated class
    attributes are ignored by ``@dataclass`` and could not be overridden via
    ``Config(...)``. Any value can now be overridden at construction time, e.g.
    ``Config(lr=2e-5)``.

    Raises:
        ValueError: if ``epochs`` and ``num_epochs`` disagree, or if
            ``gradient_accumulation_steps`` is less than 1.
    """

    loss_type: LossType = LossType.TOPK
    lr: float = 1e-5  # alternative value tried: 2e-5
    weight_decay: float = 1e-2
    # `epochs` and `num_epochs` duplicate each other and both are read by the training code;
    # they must stay equal (checked in __post_init__).
    epochs: int = 5
    scheduler: str = "cosine"
    frac_warmup: float = 0.05  # fraction of the optimizer steps used for LR warmup
    num_epochs: int = 5
    batch_size: int = 2
    gradient_accumulation_steps: int = 16  # effective batch = batch_size * 16 = 32
    seed: int = 42
    max_length: int = 1024  # tokenizer padding / truncation length
    dataset_len: Optional[int] = 64  # cap on rows per split; None keeps every row
    binary: bool = True
    heads: int = 2
    source_weight: float = 1.0
    aux_weight: float = 1.0
    mix_rate_lower_bound: float = 0.1
    use_group_labels: bool = False
    mixed_precision: bool = False
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Evaluated once, when this class is defined (i.e. when this cell runs), not per instance.
    exp_dir: str = f"output/mtd/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

    def __post_init__(self):
        if self.epochs != self.num_epochs:
            raise ValueError(
                f"epochs ({self.epochs}) and num_epochs ({self.num_epochs}) must be equal"
            )
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")


In [6]:
# Instantiate the run configuration with its default hyperparameters (see the Config cell).
conf = Config()

# Model and Dataset

In [7]:
# Hub checkpoint: CodeGen-350M-mono fine-tuned for measurement prediction.
model_path = "oliverdk/codegen-350M-mono-measurement_pred"

# trust_remote_code lets the checkpoint supply its own config / model classes if it has any.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
pretrained_model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    config=config,
    device_map=conf.device,  # load the weights directly onto the configured device
    trust_remote_code=True,
)



In [8]:
# Left padding keeps each sequence's final real token at position -1, which is the position
# MeasurementPredBackbone reads its features from (last_hidden_state[:, -1, :]).
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", trust_remote_code=True)
# Pad with the EOS token (presumably the tokenizer has no dedicated pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

# dataset and model

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
from datasets import load_dataset

# Hugging Face Hub identifier of the diamonds measurement-tampering benchmark (seed 0).
DIAMONDS_DATASET_ID = "redwoodresearch/diamonds-seed0"
dataset = load_dataset(DIAMONDS_DATASET_ID)

In [10]:
class DiamondsDataset(Dataset):
    """Torch ``Dataset`` adapter over one diamonds split.

    Each item is an ``(encoding, label, group_labels)`` triple:

    * ``encoding``: the tokenizer output for ``item['text']``, padded / truncated to
      ``max_length``, with the leading batch axis of size 1 removed so a DataLoader can
      collate items.
    * ``label``: ``item['is_correct']`` as a tensor.
    * ``group_labels``: ``[is_correct, all(measurements)]`` stacked into one tensor.
    """

    def __init__(self, dataset, tokenizer, max_length=1024):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        tokenized = self.tokenizer(
            record['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' on a single string adds a batch axis of size 1; drop it.
        encoding = dict((key, tensor.squeeze(0)) for key, tensor in tokenized.items())

        is_correct = torch.tensor(record['is_correct'])
        measurements_all_true = torch.tensor(all(record['measurements'])).float()
        group_labels = torch.stack((is_correct, measurements_all_true))

        return encoding, is_correct, group_labels

In [11]:
# Optional debugging cap: keep only the first `conf.dataset_len` rows of every split.
# The cap is clamped to each split's size so a split shorter than the cap is kept whole
# instead of `select` failing on out-of-range indices.
if conf.dataset_len is not None:
    for split_name, split in list(dataset.items()):
        n_keep = min(conf.dataset_len, len(split))
        dataset[split_name] = split.select(range(n_keep))

In [12]:
# source (is clean)
val_frac = 0.2
source_data = dataset["train"].filter(lambda x: x["is_clean"])
splits = source_data.train_test_split(train_size=1-val_frac, test_size=val_frac, seed=conf.seed)
dataset["source_train"] = splits['train']
dataset["source_val"] = splits['test']

# target (is not clean)
target_data = dataset["train"].filter(lambda x: not x["is_clean"])
dataset["target_train"] = target_data
dataset["target_val"] = dataset["train_for_val"]

# test (validation)
dataset["test"] = dataset["validation"]

# remove train and train_for_val
dataset.pop("train")
dataset.pop("train_for_val")
dataset.pop("validation")


Dataset({
    features: ['text', 'is_correct', 'is_clean', 'measurements', 'difficulty'],
    num_rows: 64
})

In [13]:
def is_tampering(x):
    """True for an incorrect example (`is_correct` falsy) where at least one measurement is positive."""
    if x["is_correct"]:
        return False
    return any(x["measurements"])

def fake_positive(x):
    """True for an incorrect example (`is_correct` falsy) whose measurements are all positive."""
    if x["is_correct"]:
        return False
    return all(x["measurements"])

def split_rate(dataset, predicate) -> float:
    """Fraction of rows in `dataset` for which `predicate(row)` is true.

    `dataset` must support ``len()`` and ``.filter(predicate)`` (as an HF ``Dataset`` does).
    Returns nan for an empty split instead of raising ZeroDivisionError.
    """
    n_rows = len(dataset)
    if n_rows == 0:
        return float("nan")
    return len(dataset.filter(predicate)) / n_rows


def split_tampering_rate(dataset):
    """Fraction of rows in `dataset` flagged by `is_tampering` (nan for an empty split)."""
    return split_rate(dataset, is_tampering)


def split_fake_positive_rate(dataset):
    """Fraction of rows in `dataset` flagged by `fake_positive` (nan for an empty split)."""
    return split_rate(dataset, fake_positive)
# Tampering / fake-positive rates per split, printed in a fixed order. The rates are kept in
# `split_rates` as {split_key: (tampering_rate, fake_positive_rate)} for later inspection.
split_display_names = {
    "source_train": "source train",
    "target_train": "target train",
    "source_val": "source val",
    "target_val": "target val",
    "test": "test",
}
split_rates = {
    split_key: (split_tampering_rate(dataset[split_key]), split_fake_positive_rate(dataset[split_key]))
    for split_key in split_display_names
}
for split_key, display_name in split_display_names.items():
    tampering_rate, fake_positive_rate = split_rates[split_key]
    print(f"{display_name}: tampering {tampering_rate:.2f}, fake positive {fake_positive_rate:.2f}")


source train: tampering 0.00, fake positive 0.00
target train: tampering 0.29, fake positive 0.12
source val: tampering 0.00, fake positive 0.00
target val: tampering 0.53, fake positive 0.45
test: tampering 0.39, fake positive 0.31


In [14]:
# Wrap each split in the torch Dataset adapter; tokenization happens lazily per item.
source_train_ds, source_val_ds, target_train_ds, target_val_ds, test_ds = (
    DiamondsDataset(dataset[split], tokenizer, max_length=conf.max_length)
    for split in ("source_train", "source_val", "target_train", "target_val", "test")
)


In [15]:
from torch.utils.data import DataLoader

# Sanity check: build an unshuffled loader over the source training split and pull one
# collated batch (encoding dict, labels, group labels) for inspection below.
dataloader = DataLoader(dataset=source_train_ds, batch_size=conf.batch_size)
batches = iter(dataloader)
x = next(batches)

In [16]:
# Display the sample batch pulled above: (tokenizer encoding dict, labels, group labels).
x[0], x[1], x[2]

({'input_ids': tensor([[50256, 50256, 50256,  ...,   685, 42848,   198],
          [50256, 50256, 50256,  ...,   685, 42848,   198]]),
  'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
          [0, 0, 0,  ..., 1, 1, 1]])},
 tensor([True, True]),
 tensor([[1., 1.],
         [1., 1.]]))

In [17]:
class MeasurementPredBackbone(nn.Module):
    """Feature extractor returning the final-position hidden state of a HF model's base transformer.

    ``forward`` takes a dict with ``input_ids`` and ``attention_mask``. It runs
    ``pretrained_model.base_model`` and returns ``last_hidden_state[:, -1, :]``, one feature
    vector per sequence.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.pretrained_model = pretrained_model

    def forward(self, x):
        transformer = self.pretrained_model.base_model
        hidden = transformer(x['input_ids'], attention_mask=x['attention_mask']).last_hidden_state
        # With left padding, position -1 holds the last real token of every sequence.
        return hidden[:, -1]


In [18]:
# pred_model = MeasurementPredBackbone(pretrained_model).to('cpu')
# out = pred_model(x[0])
# out.shape

# Train

In [19]:
from transformers import get_scheduler

# Seed the global torch RNG before the heads are built, so their initialization and the
# training-loader shuffling below are reproducible. Previously conf.seed only fed the
# train/val split.
torch.manual_seed(conf.seed)

pred_model = MeasurementPredBackbone(pretrained_model).to(conf.device)
# NOTE(review): feature_dim presumably has to match the backbone hidden size (width of the
# last-token hidden state); 1024 is assumed for codegen-350M — confirm against the model config.
net = MultiHeadBackbone(pred_model, n_heads=conf.heads, feature_dim=1024, classes=1).to(conf.device)

# Training loaders are reshuffled every epoch; evaluation loaders keep a fixed order.
source_train_loader = DataLoader(source_train_ds, batch_size=conf.batch_size, shuffle=True)
target_train_loader = DataLoader(target_train_ds, batch_size=conf.batch_size, shuffle=True)
source_val_loader = DataLoader(source_val_ds, batch_size=conf.batch_size)
target_val_loader = DataLoader(target_val_ds, batch_size=conf.batch_size)
target_test_loader = DataLoader(test_ds, batch_size=conf.batch_size)

opt = torch.optim.AdamW(net.parameters(), lr=conf.lr, weight_decay=conf.weight_decay)

# One optimizer/scheduler step per `gradient_accumulation_steps` source micro-batches.
num_training_steps = conf.num_epochs * len(source_train_loader) // conf.gradient_accumulation_steps
scheduler = get_scheduler(
    name=conf.scheduler,
    optimizer=opt,
    # get_scheduler takes an integer warmup step count. A fractional value (e.g. 0.05) makes
    # step 0 a warmup step with learning rate 0.
    num_warmup_steps=int(conf.frac_warmup * num_training_steps),
    num_training_steps=num_training_steps,
)

if conf.loss_type == LossType.DIVDIS:
    loss_fn = DivDisLoss(heads=conf.heads)
elif conf.loss_type == LossType.TOPK:
    loss_fn = ACELoss(
        heads=conf.heads,
        classes=2,
        binary=conf.binary,
        mode="topk",
        mix_rate=conf.mix_rate_lower_bound,
        pseudo_label_all_groups=False,
        device=conf.device,
    )
else:
    # Fail here rather than later with a NameError on `loss_fn` in the training loop.
    raise NotImplementedError(f"Unsupported loss type: {conf.loss_type}")

In [20]:
def compute_src_losses(logits, y, gl, binary, use_group_labels, heads=None):
    """Per-head supervised losses on a labeled (source) batch.

    Args:
        logits: model output whose last dimension concatenates the heads' outputs (one
            logit per head when ``binary``, otherwise one block of class logits per head).
        y: task labels, shared by every head.
        gl: group labels. When ``use_group_labels`` is true, its last dimension is chunked
            into ``heads`` pieces and piece ``i`` supervises head ``i`` instead of ``y``.
        binary: use BCE-with-logits if true, otherwise cross-entropy.
        use_group_labels: supervise each head with its own group label instead of ``y``.
        heads: number of heads; defaults to ``conf.heads``.

    Returns:
        A list with one scalar loss tensor per head.
    """
    n_heads = conf.heads if heads is None else heads
    logits_chunked = torch.chunk(logits, n_heads, dim=-1)
    # Replicate the task labels once per head so chunking yields exactly one copy per head,
    # for any number of heads (a fixed two copies mis-split the labels when heads != 2).
    labels = gl if use_group_labels else torch.cat([y] * n_heads, dim=-1)
    labels_chunked = torch.chunk(labels, n_heads, dim=-1)
    if binary:
        return [
            F.binary_cross_entropy_with_logits(head_logits.squeeze(), head_labels.squeeze().to(torch.float32))
            for head_logits, head_labels in zip(logits_chunked, labels_chunked)
        ]
    return [
        F.cross_entropy(head_logits.squeeze(), head_labels.squeeze().to(torch.long))
        for head_logits, head_labels in zip(logits_chunked, labels_chunked)
    ]

def compute_corrects(logits: torch.Tensor, head: int, y: torch.Tensor, binary: bool, heads=None):
    """Number of examples in the batch that head `head` classifies correctly.

    Binary: head `head` owns column `head` of `logits` and predicts positive when its
    logit is > 0. Multi-class: `logits` is reshaped to (batch, heads, classes) and the head
    predicts the argmax over classes; `heads` defaults to ``conf.heads``.

    `y` is flattened in both branches so (batch,) and (batch, 1) targets behave the same.
    An unflattened (batch, 1) target would broadcast against the (batch,) predictions and
    over-count matches.
    """
    targets = y.flatten()
    if binary:
        predictions = logits[:, head] > 0
    else:
        n_heads = conf.heads if heads is None else heads
        predictions = logits.view(logits.size(0), n_heads, -1)[:, head].argmax(dim=-1)
    return (predictions == targets).sum().item()
        

In [24]:
# TODO: compute the loss over the full accumulated batch rather than per mini-batch (unclear whether this is feasible).

torch.Size([2, 2])

In [21]:
# TODO: rename the metric dictionary keys to "source loss" / "target loss".

classes = 2

# Column of the group labels used as the "alt" target: all(measurements) (see DiamondsDataset).
alt_index = 1

import json
import os
import warnings
from collections import defaultdict
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from torch.amp import GradScaler, autocast
from sklearn.metrics import roc_auc_score


def _nan_mean(values):
    """Mean of `values`, or nan (without numpy's empty-slice RuntimeWarning) when empty."""
    return float(np.mean(values)) if len(values) > 0 else float("nan")


def _safe_auroc(labels, scores):
    """ROC AUC, or nan when `labels` holds a single class (roc_auc_score raises ValueError then)."""
    if len(np.unique(labels)) < 2:
        return float("nan")
    return roc_auc_score(labels, scores)


metrics = defaultdict(list)
writer = SummaryWriter(log_dir=conf.exp_dir)

# Mixed precision (conf.mixed_precision) is not wired up yet; GradScaler/autocast are imported for it.
# TODO (not implemented): to support batch-level target losses under gradient accumulation, compute
# all target logits without grad first, then iterate over mini-batches and replace each slice of
# logits with its grad-tracked version.

# Micro-batches are counted globally across epochs. A per-epoch counter never reaches the
# accumulation window when an epoch has fewer batches than gradient_accumulation_steps, and
# then the optimizer never steps. Global counting also matches the scheduler's length,
# num_epochs * len(source_train_loader) // gradient_accumulation_steps.
accum_steps = conf.gradient_accumulation_steps
micro_step = 0
if conf.num_epochs * len(source_train_loader) < accum_steps:
    warnings.warn("Fewer micro-batches than gradient_accumulation_steps: no optimizer step will be taken.")

# from_pretrained() leaves the HF model in eval mode. Switch every submodule to train mode
# now, not only after the first evaluation pass calls net.train().
net.train()
opt.zero_grad()

target_iter = iter(target_train_loader)
for epoch in range(conf.num_epochs):
    for batch_idx, (x, y, gl) in tqdm(enumerate(source_train_loader), desc="Source train", total=len(source_train_loader)):
        # TensorBoard step for both source and target losses (driven by the source loader).
        global_step = epoch * len(source_train_loader) + batch_idx
        x, y, gl = to_device(x, y, gl, conf.device)
        logits = net(x)
        losses = compute_src_losses(logits, y, gl, conf.binary, conf.use_group_labels)
        xent = sum(losses)
        writer.add_scalar("train/source_loss", xent.item(), global_step)

        # Target loss: cycle through the target loader alongside the source loader
        # (target labels are not used by loss_fn).
        try:
            target_batch = next(target_iter)
        except StopIteration:
            target_iter = iter(target_train_loader)
            target_batch = next(target_iter)
        target_x, target_y, target_gl = to_device(*target_batch, conf.device)
        target_logits = net(target_x)
        target_loss = loss_fn(target_logits)
        writer.add_scalar("train/target_loss", target_loss.item(), global_step)
        writer.add_scalar("train/weighted_target_loss", conf.aux_weight * target_loss.item(), global_step)

        full_loss = conf.source_weight * xent + conf.aux_weight * target_loss
        # Scale so the accumulated gradient is the mean over the window rather than the sum.
        (full_loss / accum_steps).backward()
        micro_step += 1
        if micro_step % accum_steps == 0:
            opt.step()
            if scheduler is not None:
                scheduler.step()
            opt.zero_grad()

        metrics["xent"].append(xent.item())
        metrics["repulsion_loss"].append(target_loss.item())

    # ---- Evaluation after every epoch ----
    net.eval()
    with torch.no_grad():
        # Repulsion (target) loss on the target validation set, used for model selection.
        target_losses_val = []
        for batch in tqdm(target_val_loader, desc="Target val"):
            x, y, gl = to_device(*batch, conf.device)
            target_loss_val = loss_fn(net(x))
            # loss_fn can return nan for some batches; skip them instead of poisoning the mean.
            if not target_loss_val.isnan():
                target_losses_val.append(target_loss_val.item())
        target_val_loss = _nan_mean(target_losses_val)
        weighted_target_val_loss = conf.aux_weight * target_val_loss

        # Cross-entropy on the source validation set.
        xent_val = []
        for batch in tqdm(source_val_loader, desc="Source val"):
            x, y, gl = to_device(*batch, conf.device)
            losses_val = compute_src_losses(net(x), y, gl, conf.binary, conf.use_group_labels)
            xent_val.append(sum(losses_val).item())
        source_val_xent = _nan_mean(xent_val)

    metrics["target_val_repulsion_loss"].append(target_val_loss)
    metrics["target_val_weighted_repulsion_loss"].append(weighted_target_val_loss)
    metrics["source_val_xent"].append(source_val_xent)
    metrics["val_loss"].append(target_val_loss + source_val_xent)
    metrics["val_weighted_loss"].append(weighted_target_val_loss + source_val_xent)
    writer.add_scalar("val/target_loss", target_val_loss, epoch)
    writer.add_scalar("val/weighted_target_loss", weighted_target_val_loss, epoch)
    writer.add_scalar("val/source_loss", source_val_xent, epoch)
    writer.add_scalar("val/val_loss", target_val_loss + source_val_xent, epoch)
    writer.add_scalar("val/weighted_val_loss", weighted_target_val_loss + source_val_xent, epoch)

    # Accuracy and AUROC over the test split (reported only; not used for model selection).
    total_correct = torch.zeros(conf.heads)
    total_correct_alt = torch.zeros(conf.heads)
    total_samples = 0
    all_preds = [[] for _ in range(conf.heads)]
    all_labels = []
    all_labels_alt = []
    with torch.no_grad():
        for test_batch in tqdm(target_test_loader, desc="Target test"):
            test_x, test_y, test_gl = to_device(*test_batch, conf.device)
            test_logits = net(test_x)
            assert test_logits.shape == (batch_size(test_x), conf.heads * (1 if conf.binary else classes))
            total_samples += test_y.size(0)

            all_labels.extend(test_y.cpu().numpy())
            all_labels_alt.extend(test_gl[:, alt_index].cpu().numpy())
            for i in range(conf.heads):
                total_correct[i] += compute_corrects(test_logits, i, test_y, conf.binary)
                total_correct_alt[i] += compute_corrects(test_logits, i, test_gl[:, alt_index], conf.binary)
                # Sigmoid of the per-head logit; only a class probability in the binary setup.
                all_preds[i].extend(torch.sigmoid(test_logits[:, i]).cpu().numpy())

    for i in range(conf.heads):
        auroc = _safe_auroc(all_labels, all_preds[i])
        auroc_alt = _safe_auroc(all_labels_alt, all_preds[i])
        metrics[f"epoch_auroc_{i}"].append(auroc)
        metrics[f"epoch_auroc_{i}_alt"].append(auroc_alt)
        writer.add_scalar(f"val/auroc_{i}", auroc, epoch)
        writer.add_scalar(f"val/auroc_{i}_alt", auroc_alt, epoch)
        print(f"Epoch {epoch + 1} AUROC {i}: {auroc:.4f}, Alt: {auroc_alt:.4f}")

    for i in range(conf.heads):
        acc = (total_correct[i] / total_samples).item()
        acc_alt = (total_correct_alt[i] / total_samples).item()
        metrics[f"epoch_acc_{i}"].append(acc)
        metrics[f"epoch_acc_{i}_alt"].append(acc_alt)
        writer.add_scalar(f"val/acc_{i}", acc, epoch)
        writer.add_scalar(f"val/acc_{i}_alt", acc_alt, epoch)

    print(f"Epoch {epoch + 1} Test Accuracies:")
    print(f"Target val repulsion loss: {metrics['target_val_repulsion_loss'][-1]:.4f}")
    print(f"Target val weighted repulsion loss: {metrics['target_val_weighted_repulsion_loss'][-1]:.4f}")
    print(f"Source val xent: {metrics['source_val_xent'][-1]:.4f}")
    print(f"val loss: {metrics['val_loss'][-1]:.4f}")
    print(f"val weighted loss: {metrics['val_weighted_loss'][-1]:.4f}")
    for i in range(conf.heads):
        print(f"Head {i}: {metrics[f'epoch_acc_{i}'][-1]:.4f}, Alt: {metrics[f'epoch_acc_{i}_alt'][-1]:.4f}")

    net.train()

# Discard any trailing partial accumulation window so stale gradients don't leak into later cells.
opt.zero_grad()
writer.close()

metrics = dict(metrics)
# Save metrics next to the TensorBoard logs.
os.makedirs(conf.exp_dir, exist_ok=True)
with open(f"{conf.exp_dir}/metrics.json", "w") as f:
    json.dump(metrics, f, indent=4)

Source train: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]
Target val: 100%|██████████| 32/32 [00:04<00:00,  7.52it/s]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
Source val: 100%|██████████| 2/2 [00:00<00:00,  7.53it/s]
Target test: 100%|██████████| 32/32 [00:04<00:00,  7.48it/s]


Epoch 1 AUROC 0: 0.5896, Alt: 0.3835
Epoch 1 AUROC 1: 0.1682, Alt: 0.0777
Epoch 1 Test Accuracies:
Target val repulsion loss: nan
Target val weighted repulsion loss: nan
Source val xent: 4.0897
val loss: nan
val weighted loss: nan
Head 0: 0.5625, Alt: 0.2500
Head 1: 0.3594, Alt: 0.1406


Source train: 100%|██████████| 6/6 [00:04<00:00,  1.38it/s]
Target val:  88%|████████▊ | 28/32 [00:03<00:00,  7.45it/s]


KeyboardInterrupt: 

In [None]:
# Inspect the stored test-set sigmoid outputs of the last head. The index is explicit rather than
# relying on the loop variable `i` left over from the training cell, which is hidden state that
# changes if that cell is edited or interrupted.
head_to_inspect = conf.heads - 1
all_preds[head_to_inspect]

In [None]:
# Test labels alongside the last head's stored sigmoid outputs. The head index is explicit
# instead of reusing the loop variable `i` leaked from the training cell.
head_to_inspect = conf.heads - 1
all_labels, all_preds[head_to_inspect]