In [1]:
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import torch
from model_util import load_checkpoint, set_extra_embeddings,set_separate_lm_head, set_separate_embeddings, set_transformed_lm_head


In [2]:
def load_data(path):
    data = []
    dataframe = pd.read_csv(path, sep = '\t')
    sentences = dataframe['sentence'].to_list()
    labels = dataframe['label'].to_list()
    for i,j in zip(sentences, labels):
        data.append((i,j))
    return data

In [3]:
train_data = load_data('/Users/kruthaydonapati/Downloads/Channel-LM-Prompting-main/data/k-shot/SST-2/16-13/train.tsv')

In [4]:
test_data = load_data('/Users/kruthaydonapati/Downloads/Channel-LM-Prompting-main/data/k-shot/SST-2/16-13/test.tsv')

In [5]:
def generate_templates():
    generic_template = "A %s one . "
    labels = [ 'terrible' , 'great']
    templates = [generic_template % label for label in labels]
    return templates

In [6]:
templates = generate_templates()

In [7]:
def prepare_data_for_training(train_data, tokenizer, templates, max_length = 128, max_length_per_example = 128):
    bos_token_id = tokenizer.bos_token_id
    eos_token_id = tokenizer.eos_token_id
    
    templates = [template.strip() for template in templates]
    train_data = [(" "+sent, label) for sent, label in train_data]
    prefixes_with_space = [tokenizer(" "+template)["input_ids"] for template in templates]
    train_inputs = [tokenizer(sent)["input_ids"] for sent, _ in train_data]
    truncated = np.sum([len(inputs)>max_length_per_example-16 for inputs in train_inputs])
    if truncated > 0:
        train_inputs = [inputs[:max_length_per_example-16] for inputs in train_inputs]
        print ("%d/%d truncated" % (truncated, len(train_inputs)))

    prefixes = [tokenizer(template)["input_ids"] for template in templates]
    idx = [idx for idx, _prefixes in enumerate(zip(*prefixes)) 
           if not np.all([_prefixes[0]==_prefix for _prefix in _prefixes])][0]
    input_ids, attention_mask, token_type_ids = [], [], []
    for train_input, dp in zip(train_inputs, train_data):
        prefix = prefixes[int(dp[1])]
        encoded = prepro_sentence_pair_single(prefix, train_input, max_length, bos_token_id, eos_token_id)
        input_ids.append(encoded[0])
        attention_mask.append(encoded[1])
        token_type_ids.append(encoded[2])

    return dict(input_ids=torch.LongTensor(input_ids),
                attention_mask=torch.LongTensor(attention_mask),
                token_type_ids=torch.LongTensor(token_type_ids))

In [34]:
def prepare_data_for_all(train_data, test_data, tokenizer, templates, max_length = 128, max_length_per_example = 128):
    bos_token_id = tokenizer.bos_token_id
    eos_token_id = tokenizer.eos_token_id
    templates = [template.strip() for template in templates]
    test_data = [(" "+sent, label) for sent, label in test_data]
    train_data = [(" "+sent, label) for sent, label in train_data]
    prefixes_with_space = [tokenizer(" "+template)["input_ids"] for template in templates]
    test_inputs = [tokenizer(sent)["input_ids"] for sent, _ in test_data]
    truncated = np.sum([len(inputs)>max_length_per_example-16 for inputs in test_inputs])
    if truncated > 0:
        test_inputs = [inputs[:max_length_per_example-16] for inputs in test_inputs]
        print ("%d/%d truncated" % (truncated, len(test_inputs)))
    prefixes = [tokenizer(template)["input_ids"] for template in templates]
    idx = [idx for idx, _prefixes in enumerate(zip(*prefixes))
           if not np.all([_prefixes[0]==_prefix for _prefix in _prefixes])][0]
    '''    for i in range(n_classes):
        for j in range(i+1, n_classes):
            assert prefixes[i][:idx]==prefixes[j][:idx]
            assert prefixes[i][idx]!=prefixes[j][idx] '''
    input_tensors = []
    for i in range(n_classes):
        prefix = prefixes[i].copy()
        tensor = prepro_sentence_pair([prefix], test_inputs, max_length,bos_token_id, 
                                      eos_token_id,allow_truncation=use_demonstrations)
    input_tensors.append(tensor)
    return input_tensors

In [8]:
n_prefix = 20

In [32]:
def prepro_sentence_pair(train_inputs, test_inputs, max_length,
                         bos_token_id, eos_token_id,
                         allow_truncation=False):
    input_ids, attention_mask, token_type_ids = [], [], []
    for test_input in test_inputs:
        for train_input in train_inputs:
            _input_ids, _attention_mask, _token_type_ids = \
                prepro_sentence_pair_single(train_input, test_input, max_length,
                                            bos_token_id, eos_token_id,
                                            allow_truncation=allow_truncation)
            input_ids.append(_input_ids)
            attention_mask.append(_attention_mask)
            token_type_ids.append(_token_type_ids)

    return {"input_ids": torch.LongTensor(input_ids),
            "attention_mask": torch.LongTensor(attention_mask),
            "token_type_ids": torch.LongTensor(token_type_ids)}


In [12]:
def prepro_sentence_pair_single(ids1, ids2, max_length,
                                bos_token_id, eos_token_id, negate=False,
                                allow_truncation=False):
    if bos_token_id is not None:
        ids1 = [bos_token_id] + ids1
    if eos_token_id is not None:
        ids2 = ids2 + [eos_token_id]
    if allow_truncation and len(ids1)+len(ids2) > max_length:
        ids1 = ids1[len(ids1)+len(ids2)-max_length:] # len = max_length-len(ids2)
        assert len(ids1)+len(ids2)==max_length

    n_mask = max_length-len(ids1)-len(ids2)
    assert n_mask>=0, (max_length, len(ids1), len(ids2))
    input_ids = ids1+ids2+[0 for _ in range(n_mask)]
    attention_mask = [1 for _ in ids1+ids2] + [0 for _ in range(n_mask)]
    if negate:
        token_type_ids = [0 for _ in ids1] + [-1 for _ in ids2] + [0 for _ in range(n_mask)]
    else:
        token_type_ids = [0 for _ in ids1] + [1 for _ in ids2] + [0 for _ in range(n_mask)]
    return input_ids, attention_mask, token_type_ids

In [35]:
def run(train_data, test_data, templates):
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    inputs = prepare_data_for_training(train_data, tokenizer, templates)
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    for param in model.parameters():
        param.requires_grad = False
    #set_extra_embeddings(model, n_prefix)
    #inputs = prepend_task_tokens(tokenizer, inputs, n_prefix)
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    train(inputs)
    input_tensors = prepare_data_for_all(train_data, test_data, tokenizer, templates, max_length = 128, max_length_per_example = 128)
    
    input_tensors = prepend_task_tokens(tokenizer, input_tensors, n_prefix)
    

In [36]:
def train(model, inputs, batch_size,learning_rate=1e-5, warmup_steps=50, num_training_steps=200,
          gradient_accumulation_steps=1, max_grad_norm=1.0, eval_period=20):
    optimizer, scheduler = get_optimizer_and_scheduler("adamw", model, learning_rate=learning_rate, 
                                                       warmup_steps=warmup_steps, 
                                                       num_training_steps=num_training_steps)
    
    dataloader = get_dataloader(inputs, batch_size, is_training=True)
    n_trainable_params = len([param for param in model.parameters() if param.requires_grad])
    n_gpus = torch.cuda.device_count()
    model.train()
    global_step = 0
    train_losses = []
    best_accuracy = -1
    stop_training=False
    for epoch in range(num_training_steps):
        for batch in dataloader:
            global_step += 1

            input_ids=batch[0].cuda()
            attention_mask=batch[1].cuda()
            token_type_ids=batch[2].cuda()
            
            if len(batch)==3:
                labels=None
            else:
                labels=batch[3].cuda()

            loss = run_model(model, input_ids, attention_mask, token_type_ids, labels=labels)
            loss = loss.mean()
            
            if torch.isnan(loss).data:
                print ("Stop training because loss=%s" % (loss.data))
                stop_training=True
                break

            train_losses.append(loss.detach().cpu())
            
            loss.backward()
            
            if global_step % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()    # We have accumulated enought gradients
                model.zero_grad()
                if scheduler is not None:
                    scheduler.step()
            
            if global_step % eval_period == 0:
                keys = ["transformer.wte.new_embed.weight"]
                model_state_dict = {key: model.state_dict()[key if n_gpus==1 else "module."+key].cpu() for key in keys}
                
                torch.save(model_state_dict,
                           os.path.join(output_dir, "model-{}.pt".format(global_step)))

                train_losses = []
            if global_step==num_training_steps:
                break
        if global_step==num_training_steps:
            break

In [37]:
def prepend_task_tokens(tokenizer, inputs, n_prefix):
    task_tokens = ["<TASK{}>".format(str(i).zfill(2)) for i in range(n_prefix)]
    tokenizer.add_tokens(task_tokens)
    task_token_ids = tokenizer(" ".join(task_tokens), return_tensors="pt")["input_ids"]
    assert task_token_ids.shape[-1]==n_prefix

    def convert(inputs):
        n_train = inputs["input_ids"].shape[0]

        new_input_ids=torch.cat([
                task_token_ids.repeat(n_train, 1),
                inputs["input_ids"][:,1:]], 1)

        inputs = dict(
            input_ids=new_input_ids,
            attention_mask=torch.cat([
                torch.ones((n_train, n_prefix-1), dtype=torch.long),
                inputs["attention_mask"]], 1),
            token_type_ids=torch.cat([
                torch.zeros((n_train, n_prefix-1), dtype=torch.long),
                inputs["token_type_ids"]], 1),
            labels=torch.cat([
                torch.zeros((n_train, n_prefix-1), dtype=torch.long),
                inputs["input_ids"]], 1))
        return inputs

    if type(inputs)==list:
        return [convert(_inputs) for _inputs in inputs]

    return convert(inputs)

In [39]:
run(train_data, test_data, templates)

AssertionError: Torch not compiled with CUDA enabled