### Example code for using SubMix

In [None]:
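# First, get your model ensemble.
# You also need the wikitext-103 dataset, saved as torch tensors in
# wikitext-103-{valid,train,test}-corpus.pt (see load_wikitext below).
# TODO: add a link to the script that downloads and prepares these files.
# If you don't want to fine-tune your own ensemble, a pretrained 8-fold
# partition is available at
# https://dl.fbaipublicfiles.com/submix/ensembles/gpt2_wikitext_finetune_8fold
# (the next cell downloads it).
# You'll also need a public LM from Hugging Face (GPT-2 is used below).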

In [None]:
# Set the configurations here.
#   B      -- query budget (counted per token, not per sequence)
#   eps    -- target epsilon (information leakage)
#   alpha  -- alpha for Renyi Divergence
#   seqlen -- sequence length
B, eps, alpha, seqlen = 1024, 2, 2, 512


In [None]:
import os
import urllib.request

import torch
from transformers import GPT2LMHeadModel

public_model = GPT2LMHeadModel.from_pretrained('gpt2')

# download 8 fold ensembles for wikitext-103
N_FOLDS = 8
s3path = 'https://dl.fbaipublicfiles.com/submix/ensembles'
s3path += '/gpt2_wikitext_finetune_8fold'
ensemble = []
for i in range(N_FOLDS):
    url = f'{s3path}/{i}'
    local_path = f'{i}'
    print(url)
    # Skip files that are already present so re-running the cell does not
    # re-download them (wget used to leave duplicates such as '0.1').
    # urlretrieve raises on network/HTTP errors instead of failing silently.
    # Download to a temporary name first so an interrupted download is not
    # mistaken for a complete file on the next run.
    if not os.path.exists(local_path):
        tmp_path = local_path + '.part'
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, local_path)
    # map_location='cpu' lets checkpoints that were saved from a GPU load on
    # CPU-only machines (without it torch.load fails before .to('cpu') runs).
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints from a source you trust.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which likely rejects
    # these full-model pickles; pass weights_only=False there if you trust the source.
    ensemble.append(torch.load(local_path, map_location='cpu').to('cpu'))

In [None]:
from submix import SubMix

# Build the SubMix predictor from the public LM and the fine-tuned ensemble,
# using the query budget B, target eps and Renyi alpha from the config cell.
# NOTE(review): the first positional argument, 'cpu', appears to be the device — confirm.
SM = SubMix(
    'cpu',
    B,
    eps,
    public_model,
    ensemble,
    alpha=alpha,
)

In [None]:
# Utility functions and classes: corpus loading, fixed-length chunking, cross-entropy, and the SubMix evaluation loop.
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
def load_wikitext():
    """Load the saved wikitext-103 corpus for every split.

    Reads ``wikitext-103-<split>-corpus.pt`` from the current working directory
    with ``torch.load`` for the splits ``valid``, ``train`` and ``test``, in
    that order.

    Returns:
        dict: split name -> whatever object ``torch.load`` returned for it.
    """
    splits = ("valid", "train", "test")
    return {split: torch.load(f"wikitext-103-{split}-corpus.pt") for split in splits}

class CorpusDataset(Dataset):
    """Map-style dataset that cuts a flat corpus into consecutive fixed-length chunks.

    Item ``k`` is ``corpus[k * seqlen : (k + 1) * seqlen]``. Trailing elements
    that do not fill a whole chunk are dropped.

    Args:
        corpus: sliceable sequence with ``len`` (e.g. a 1-D tensor of token ids).
        seqlen: chunk length; must be positive.

    Raises:
        ValueError: if ``seqlen`` is not positive.
    """

    def __init__(self, corpus, seqlen):
        super().__init__()
        if seqlen <= 0:
            raise ValueError(f"seqlen must be positive, got {seqlen}")
        self.corpus = corpus
        self.seqlen = seqlen

    def __len__(self):
        # Floor division is exact for any corpus size (float division can round);
        # int() keeps the result an int even if seqlen was given as a float.
        return int(len(self.corpus) // self.seqlen)

    def __getitem__(self, item):
        n = len(self)
        index = item + n if item < 0 else item
        if not 0 <= index < n:
            # Raising IndexError (rather than returning an empty slice) follows the
            # sequence protocol, so iterating directly over the dataset terminates.
            raise IndexError(f"CorpusDataset index {item} out of range (length {n})")
        start = index * self.seqlen
        return self.corpus[start : start + self.seqlen]

def XHeval(lm_logits, labels):
    """Return the mean next-token cross-entropy of ``lm_logits`` against ``labels``.

    Logit position ``t`` is scored against label ``t + 1``: the last logit
    position and the first label are dropped, both are flattened, and the
    default (mean-reduced) cross-entropy is computed.

    Args:
        lm_logits: tensor whose last dim indexes the vocabulary and whose
            second-to-last dim indexes sequence positions.
        labels: tensor of token ids whose last dim indexes sequence positions.

    Returns:
        float: the loss as a Python number.
    """
    vocab_size = lm_logits.size(-1)
    predictions = lm_logits[..., :-1, :].reshape(-1, vocab_size)
    targets = labels[..., 1:].reshape(-1)
    return nn.functional.cross_entropy(predictions, targets).item()
    
def test(
    SM,
    device,
    loader,
    max_iters,
    tokenizer=None,
):
    """Evaluate SubMix next-token predictions on up to ``max_iters`` batches.

    For every batch, ``SM.compute_logits_at_context`` returns ``P``. For each
    position ``j``, the ``j``-th entries of the items of ``P`` (presumably one
    per ensemble model) are mixed by one ``SM.query`` call, so every position
    counts against the query budget. The log of the mixed distributions is
    scored against the batch with ``XHeval``.

    Args:
        SM: SubMix object; its ``LMs`` are put in eval mode and its
            ``compute_logits_at_context`` / ``query`` methods are used.
        device: device each batch is moved to.
        loader: iterable of token-id tensors (e.g. a DataLoader over
            ``CorpusDataset``), or of raw strings when ``tokenizer`` is given.
        max_iters: maximum number of batches to evaluate.
        tokenizer: optional object with an ``encode`` method; used only for
            batches that arrive as raw strings.

    Returns:
        list of float: one mean cross-entropy per evaluated batch.

    Raises:
        ValueError: if a raw-string batch is seen and no ``tokenizer`` was given.
    """
    import tqdm  # local import: the notebook only imports tqdm in a later cell

    for model in SM.LMs:
        model.eval()
    losses = []
    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(loader)):
            # Checked before any work is done, so at most `max_iters` batches
            # are evaluated (and hence queried).
            if i >= max_iters:
                break
            # Strings have no `.to`, so encode them before moving to the device.
            if isinstance(data, str):
                if tokenizer is None:
                    raise ValueError("test() got a raw-text batch but no tokenizer")
                data = torch.tensor(tokenizer.encode(data))
            data = data.to(device)
            _, P = SM.compute_logits_at_context(data)
            # One SM.query per position j, mixing the j-th entries of P.
            L_submix = torch.stack(
                [torch.log(SM.query([p[j] for p in P])) for j in range(len(P[0]))]
            )
            losses.append(XHeval(L_submix, data))
    if losses:
        print(f"Val Loss: {np.mean(losses):.4f} ")
    else:
        print("Val Loss: n/a (no batches evaluated)")
    return losses


In [None]:
import tqdm
corpus = load_wikitext()  # load data
# DataLoader's default batch_size is 1, so each batch is a single seqlen-token sequence.
val_loader = DataLoader(CorpusDataset(corpus['valid'], seqlen))
# NOTE: The query budget is on a per token basis, not a per sequence basis.
# test() issues one SM.query per token position, so each sequence costs roughly
# `seqlen` queries and the budget B covers B // seqlen sequences. This is derived
# from `seqlen` rather than a hard-coded 512 so it stays in sync with the config.
L = test(SM, 'cpu', val_loader, B // seqlen)  # per-sequence cross-entropy losses (nats)
# test() returns cross-entropy losses, not perplexity: PPL = exp(mean loss).
# All sequences have the same length, so the per-sequence mean equals the per-token mean.
val_ppl = float(np.exp(np.mean(L))) if L else float('nan')
print(f"Val PPL: {val_ppl:.2f}")
# Because B only covers a handful of sequences, we recommend running test a large
# number of times and taking an average value in order to get a good estimate of the PPL.