In [1]:
import os
import sys
import pandas as pd
import numpy as np

from datasets import load_dataset, load_from_disk
import nltk 
nltk.download('punkt')

import evaluate

[nltk_data] Downloading package punkt to
[nltk_data]     /home/RDC/zinovyee.hub/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Read the dataset previously saved to disk; its 'train' and 'test' splits are used below.
conala_path = '../data/raw/conala'
dataset_curated = load_from_disk(conala_path)

In [3]:
# Columns copied from each split into a flat DataFrame.
_columns = ['question_id', 'intent', 'rewritten_intent', 'snippet']

train_df = pd.DataFrame(data={name: dataset_curated['train'][name] for name in _columns})
test_df = pd.DataFrame(data={name: dataset_curated['test'][name] for name in _columns})

# Stack both splits row-wise; each split keeps its own index values here.
full_df = pd.concat([train_df, test_df], axis=0)

In [4]:
# Order rows by question id and replace the index with a fresh 0..n-1 range.
full_df = (
    full_df
    .sort_values(by="question_id")
    .reset_index(drop=True)
)

In [5]:
# Number of distinct question ids across both splits.
print(full_df["question_id"].nunique())

2074


In [6]:
# Sorted array of the distinct question ids; the batches below are slices of it.
qids = np.sort(full_df["question_id"].unique())
print(qids)

[    1476     1854     3061 ... 42731970 42747987 42765620]


In [37]:
# Scratch arithmetic: how many chunks of 35 fit into (2074 - 900) ids.
# The numbers match the unique-id count printed above and the 900 / 35 used in the
# batching cell; the fractional result means the last chunk is only partly filled.
(2074-900)/35

33.542857142857144

In [8]:
# NOTE(review): scratch calculation; its purpose is not evident from this notebook.
# It also leaves `i = 15` behind as a global until a later loop rebinds `i`.
i = 15 
(i+1)*900

14400

In [39]:
# Split the sorted question ids into an initial training pool followed by
# consecutive, fixed-size batches.
n_initial_train = 900   # number of leading ids kept out of the batches
batch_size = 35         # ids per batch (the last batch may be shorter)

first_train_ids = qids[:n_initial_train]
batches = []

# Iterating over the start offsets derives the number of batches from the data
# instead of hard-coding it, so the loop stays correct if `qids` changes size.
batch_starts = range(n_initial_train, len(qids), batch_size)
for i, batch_start in enumerate(batch_starts):
    print(i)
    print(batch_start)

    batch_end = batch_start + batch_size
    # Slicing past the end of the array is safe: the final batch simply takes
    # whatever ids remain.
    batches.append(qids[batch_start:batch_end])

    if i != len(batch_starts) - 1:
        print(batch_end)
    else:
        # For the last batch the index of the final id is reported instead.
        print(len(qids)-1)

0
900
935
1
935
970
2
970
1005
3
1005
1040
4
1040
1075
5
1075
1110
6
1110
1145
7
1145
1180
8
1180
1215
9
1215
1250
10
1250
1285
11
1285
1320
12
1320
1355
13
1355
1390
14
1390
1425
15
1425
1460
16
1460
1495
17
1495
1530
18
1530
1565
19
1565
1600
20
1600
1635
21
1635
1670
22
1670
1705
23
1705
1740
24
1740
1775
25
1775
1810
26
1810
1845
27
1845
1880
28
1880
1915
29
1915
1950
30
1950
1985
31
1985
2020
32
2020
2055
33
2055
2073


In [40]:
# Tag every row with the index of the batch its question id belongs to.
# -1 is the value for rows whose question id is in none of the batches.
full_df['t_batch'] = -1
for batch_idx, batch_ids in enumerate(batches):
    in_batch = full_df['question_id'].isin(set(batch_ids))
    full_df.loc[in_batch, 't_batch'] = batch_idx

In [41]:
def postprocess_text(preds, labels):
    """Normalise decoded predictions and references before metric computation.

    Each string is stripped of surrounding whitespace and split into sentences
    with ``nltk.sent_tokenize``; the sentences are re-joined with newlines
    (rougeLSum expects a newline after each sentence).

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of lists holding the processed strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    processed_preds = [_one_sentence_per_line(pred) for pred in preds]
    processed_labels = [_one_sentence_per_line(label) for label in labels]
    return processed_preds, processed_labels

In [42]:
def batch_tokenize_preprocess(batch, tokenizer, max_input_length, max_output_length):
    """Tokenise one batch of input/output text pairs for seq2seq training.

    Args:
        batch: mapping with the keys ``"input_sequence"`` and ``"output_sequence"``.
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        max_input_length: padding/truncation length for the inputs.
        max_output_length: padding/truncation length for the targets.

    Returns:
        A dict with every field the tokenizer produced for the inputs plus a
        ``"labels"`` entry holding the target token ids, where padding
        positions are replaced by -100.
    """
    input_texts = batch["input_sequence"]
    target_texts = batch["output_sequence"]

    def _encode(texts, limit):
        # Both sides are padded and truncated to a fixed length.
        return tokenizer(texts, padding="max_length", truncation=True, max_length=limit)

    encoded_inputs = _encode(input_texts, max_input_length)
    encoded_targets = _encode(target_texts, max_output_length)

    features = dict(encoded_inputs.items())

    # Ignore padding in the loss
    pad_id = tokenizer.pad_token_id
    labels = []
    for token_ids in encoded_targets["input_ids"]:
        labels.append([-100 if token_id == pad_id else token_id for token_id in token_ids])
    features["labels"] = labels

    return features

In [43]:
def compute_metric_with_params(tokenizer, metrics_list=['rouge', 'bleu']):
    """Build a ``compute_metrics`` callback bound to a tokenizer and a metric set.

    Args:
        tokenizer: tokenizer used to decode token ids; must expose
            ``batch_decode`` and ``pad_token_id``.
        metrics_list: names of the metrics to compute. Only ``'rouge'`` and
            ``'bleu'`` are supported.

    Returns:
        A function taking ``(preds, labels)`` and returning a dict with the
        metric values (rounded to 4 decimals, BLEU ``precisions`` dropped)
        plus ``gen_len``, the mean number of non-padding predicted tokens.

    Raises:
        ValueError: if ``metrics_list`` contains an unsupported metric name.
    """
    supported = ('rouge', 'bleu')
    unsupported = [name for name in metrics_list if name not in supported]
    if unsupported:
        # Previously an unknown name left `result` unbound or silently reused
        # the previous metric's result.
        raise ValueError(
            f"Unsupported metric(s) {unsupported}; supported metrics are {list(supported)}"
        )

    # Load every metric once instead of on each evaluation call.
    loaded_metrics = {name: evaluate.load(name) for name in metrics_list}

    def compute_metrics(eval_preds):
        preds, labels = eval_preds

        if isinstance(preds, tuple):
            preds = preds[0]

        # Replace -100 as it can't be decoded. Predictions get the same guard
        # as labels (the upstream summarization example does this as well,
        # since gathered generations may be padded with -100).
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # POST PROCESSING
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        results_dict = {}
        for name, metric in loaded_metrics.items():
            # Only ROUGE takes the stemmer option.
            extra_args = {'use_stemmer': True} if name == 'rouge' else {}
            result = metric.compute(
                predictions=decoded_preds, references=decoded_labels, **extra_args
            )
            # 'precisions' is a list and cannot be rounded as a scalar.
            results_dict.update(
                {key: round(value, 4) for key, value in result.items() if key != 'precisions'}
            )

        # Generation length does not depend on the metric, so compute it once.
        prediction_lens = [
            np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
        ]
        results_dict["gen_len"] = round(np.mean(prediction_lens), 4)
        return results_dict

    return compute_metrics

In [44]:
def generate_summary(test_samples, model, tokenizer, encoder_max_length, decoder_max_length):
    """Generate output text for a set of samples with the given model.

    Args:
        test_samples: mapping/dataset exposing an ``"input_sequence"`` column.
        model: model providing ``device`` and ``generate``.
        tokenizer: tokenizer used to encode the inputs and decode the outputs.
        encoder_max_length: padding/truncation length for the inputs.
        decoder_max_length: passed to ``generate`` as ``max_new_tokens``.

    Returns:
        ``(outputs, output_str)``: the value returned by ``model.generate`` and
        its decoded strings (special tokens skipped).
    """
    encoded = tokenizer(
        test_samples["input_sequence"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )

    # Inputs must live on the same device as the model.
    target_device = model.device
    generated_ids = model.generate(
        encoded.input_ids.to(target_device),
        attention_mask=encoded.attention_mask.to(target_device),
        max_new_tokens=decoder_max_length,
    )

    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_ids, decoded

In [45]:
def train_t5(): 
    """Placeholder with no implementation yet; currently does nothing and returns None."""
    return None

In [46]:
import torch

In [47]:
# CREATE ANALYSIS FOLDER
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model_name="Salesforce/codet5-base-multi-sum"

cuda


In [48]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer

In [49]:
# Load the tokenizer and the pretrained seq2seq checkpoint named by `model_name`.
# NOTE(review): the next cell repeats both loads before freezing layers, so this
# cell looks redundant; candidate for removal.
tokenizer = AutoTokenizer.from_pretrained(model_name, skip_special_tokens=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [50]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, skip_special_tokens=False)

# Number of decoder blocks, counted from the end, that stay trainable.
N_TRAINABLE_DECODER_BLOCKS = 4

# Freeze every parameter first ...
for parameter in model.parameters():
    parameter.requires_grad = False

# ... then un-freeze only the last n transformer blocks in the decoder.
# The block count is read from the model instead of the previously hard-coded 12,
# so the same cell works for checkpoints with a different decoder depth.
first_trainable_block = max(len(model.decoder.block) - N_TRAINABLE_DECODER_BLOCKS, 0)
for block_idx, decoder_block in enumerate(model.decoder.block):
    if block_idx >= first_trainable_block:
        for parameter in decoder_block.parameters():
            parameter.requires_grad = True

In [51]:
# Fixed token lengths used for padding/truncation of the model inputs (encoder)
# and targets (decoder); DECODER_LENGTH is also the generation limit.
ENCODER_LENGTH, DECODER_LENGTH = 15, 20

In [52]:
# Move the model's parameters to the selected device and echo which one is used.
model = model.to(device)
print(device)

cuda


In [53]:
from datasets import Dataset

In [54]:
def prep_for_hf(df: pd.DataFrame, batch_id: int): 
    """Build a shuffled Hugging Face ``Dataset`` for one batch of ``df``.

    Rows whose ``t_batch`` equals ``batch_id`` are selected, ``snippet`` becomes
    ``input_sequence`` and ``intent`` becomes ``output_sequence``, and the rows
    are shuffled with a fixed seed (42) so the order is reproducible.

    Args:
        df: frame with at least the columns ``snippet``, ``intent`` and ``t_batch``.
        batch_id: value of ``t_batch`` to select.

    Returns:
        A ``Dataset`` with exactly the columns ``input_sequence`` and
        ``output_sequence``.
    """
    selected = (
        df.loc[df['t_batch'] == batch_id, ['snippet', 'intent']]
        .rename(columns={'snippet': 'input_sequence', 'intent': 'output_sequence'})
        .sample(frac=1, random_state=42)
        # Shuffling leaves a non-range index, which `from_pandas` would store as
        # an extra `__index_level_0__` column; drop it so only the two text
        # columns reach the dataset.
        .reset_index(drop=True)
    )
    return Dataset.from_pandas(selected, preserve_index=False)

In [55]:
# Rows tagged -1 (in none of the batches) form the training set; batch 0 is the
# first evaluation set.
train_dataset, test_dataset = prep_for_hf(full_df, -1), prep_for_hf(full_df, 0)


In [57]:
# Keep handles on the text datasets; the tokenised versions are derived from them.
train_data_txt = train_dataset
validation_data_txt = test_dataset


def _tokenize_for_model(batch):
    """Tokenise one batch of rows with the notebook-level tokenizer and length limits."""
    return batch_tokenize_preprocess(
        batch,
        tokenizer=tokenizer,
        max_input_length=ENCODER_LENGTH,
        max_output_length=DECODER_LENGTH,
    )


# The text columns are dropped so only the tokenizer output and labels remain.
train_data = train_data_txt.map(
    _tokenize_for_model,
    batched=True,
    batch_size=8,
    remove_columns=train_data_txt.column_names,
)

validation_data = validation_data_txt.map(
    _tokenize_for_model,
    batched=True,
    remove_columns=validation_data_txt.column_names,
)


# SUBSAMPLE FOR GENERATION BEFORE TUNING
test_samples = validation_data_txt.select(range(20))
summaries_before_tuning = generate_summary(
    test_samples, model, tokenizer, ENCODER_LENGTH, DECODER_LENGTH
)[1]

Map:   0%|          | 0/1427 [00:00<?, ? examples/s]

Map:   0%|          | 0/42 [00:00<?, ? examples/s]

In [58]:
# Metric columns picked out of the evaluation results, in display order.
eval_columns_list = ["eval_loss"]
eval_columns_list += ["eval_rouge1", "eval_rouge2", "eval_rougeL", "eval_rougeLsum"]
eval_columns_list += ["eval_bleu", "eval_gen_len"]

In [59]:
# Process-level switches for the tokenizers library, W&B logging and GPU selection.
# NOTE(review): CUDA_VISIBLE_DEVICES is set here after the model was already moved
# to the device in an earlier cell; such masks normally have to be set before CUDA
# is first used - confirm this has the intended effect.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "WANDB_DISABLED": "true",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [63]:
import gc


def _eval_summary(results):
    """Return a one-row DataFrame with the tracked metrics rounded to 3 decimals."""
    return pd.DataFrame(data=results, index=[0])[eval_columns_list].round(3)


# Release unreferenced objects and cached GPU memory before building the trainer.
gc.collect()
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    output_dir="rep/results",
    num_train_epochs=2,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=5e-4,
    warmup_steps=100,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="rep/logs",
    logging_steps=100,
    save_total_limit=1,
    # The string "none" disables all logging integrations; per the
    # TrainingArguments docs, passing None does not.
    report_to="none",
    save_strategy='epoch',
    logging_strategy='epoch',
    evaluation_strategy='epoch',
    load_best_model_at_end=False    
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

compute_metrics = compute_metric_with_params(tokenizer)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# ZERO - SHOT: evaluate on the validation set before any training step.
results_zero_shot = trainer.evaluate()
results_zero_shot_df = _eval_summary(results_zero_shot)
print(results_zero_shot_df)


# TRAINING
trainer.train()

# FINE-TUNING: evaluate again on the same validation set after training.
results_fine_tune = trainer.evaluate()
results_fine_tune_df = _eval_summary(results_fine_tune)
print(results_fine_tune_df)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      5.535        0.161        0.019        0.133           0.137   

   eval_bleu  eval_gen_len  
0        0.0         9.688  




Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len,Bleu,Brevity Penalty,Length Ratio,Translation Length,Reference Length
1,3.8736,4.13275,0.3368,0.1212,0.2796,0.2802,12.6042,0.0446,1.0,1.0022,462,461
2,2.937,4.130136,0.2843,0.0833,0.2451,0.2443,11.6667,0.0257,0.9397,0.9414,434,461




   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       4.13        0.284        0.083        0.245           0.244   

   eval_bleu  eval_gen_len  
0      0.026        11.667  


In [64]:
# Generate again for the same samples, now with the fine-tuned model.
summaries_after_tuning = generate_summary(
    test_samples=test_samples,
    model=model,
    tokenizer=tokenizer,
    encoder_max_length=ENCODER_LENGTH,
    decoder_max_length=DECODER_LENGTH,
)[1]

# Print reference, pre-tuning and post-tuning text side by side for each sample.
separator = '_'*10
for idx, description in enumerate(test_samples["output_sequence"]):
    report_lines = [
        separator,
        f'Original: {description}',
        '\n',
        f'Summary before Tuning: {summaries_before_tuning[idx]}',
        '\n',
        f'Summary after Tuning: {summaries_after_tuning[idx]}',
        '\n',
        separator,
        '\n'*2,
    ]
    print("\n".join(report_lines))

__________
Original: Removing duplicate characters from a string


Summary before Tuning: Join the set with foo.


Summary after Tuning: Python: How to convert a string to a string?


__________



__________
Original: finding index of an item closest to the value in a list that's not entirely sorted


Summary before Tuning: Returns the minimum element in an array.


Summary after Tuning: How to get the largest item in a list?


__________



__________
Original: (Django) how to get month name?


Summary before Tuning: Return today s string representation


Summary after Tuning: get current time


__________



__________
Original: sorting values of python dict using sorted builtin function


Summary before Tuning: Returns a sorted list of the keys in the dictionary.


Summary after Tuning: iterate over a dictionary in sorted order


__________



__________
Original: Convert an IP string to a number and vice versa


Summary before Tuning: Returns a string with the IPv6 address.


Summ

In [None]:
# NOTE(review): leftover inspection cell. `loss` is only assigned in the following
# cell, so on a fresh top-to-bottom run this raises NameError; candidate for removal.
loss

In [65]:
# Evaluate the already fine-tuned model on every batch after batch 0 (batch 0 was
# the validation set of the training cell) and collect loss / ROUGE-1 per batch.
loss = []
rouge_1 = []
for i in range(1, len(batches)):
    print(i)
    batch_dataset = prep_for_hf(full_df, i)
    batch_tokenized = batch_dataset.map(
        lambda batch: batch_tokenize_preprocess(
            batch,
            tokenizer=tokenizer,
            max_input_length=ENCODER_LENGTH,
            max_output_length=DECODER_LENGTH
        ),
        batched=True,
        remove_columns=batch_dataset.column_names,
    )

    batch_trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=batch_tokenized,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Distinct names keep the zero-shot results of the training cell intact;
    # these numbers come from the fine-tuned model.
    batch_results = batch_trainer.evaluate()
    batch_results_df = pd.DataFrame(data=batch_results, index=[0])[eval_columns_list].round(3)
    print(batch_results_df)

    # Store plain floats rather than one-element Series so the lists can be
    # plotted or aggregated directly.
    rouge_1.append(float(batch_results_df.loc[0, 'eval_rouge1']))
    loss.append(float(batch_results_df.loc[0, 'eval_loss']))


1


Map:   0%|          | 0/53 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.546        0.411        0.178        0.373           0.373   

   eval_bleu  eval_gen_len  
0      0.066        10.698  
2


Map:   0%|          | 0/40 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       3.73        0.386        0.138        0.351            0.35   

   eval_bleu  eval_gen_len  
0      0.053         9.825  
3


Map:   0%|          | 0/45 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.682        0.382        0.135        0.343           0.345   

   eval_bleu  eval_gen_len  
0      0.084        12.133  
4


Map:   0%|          | 0/40 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.942        0.338        0.106        0.276           0.277   

   eval_bleu  eval_gen_len  
0      0.055        12.225  
5


Map:   0%|          | 0/39 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.952        0.343        0.089        0.284           0.284   

   eval_bleu  eval_gen_len  
0        0.0        11.692  
6


Map:   0%|          | 0/46 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.819        0.414        0.137        0.354           0.354   

   eval_bleu  eval_gen_len  
0        0.0          12.0  
7


Map:   0%|          | 0/49 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       3.85        0.362        0.102        0.322           0.323   

   eval_bleu  eval_gen_len  
0        0.0        11.326  
8


Map:   0%|          | 0/51 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.112        0.308        0.086        0.278           0.278   

   eval_bleu  eval_gen_len  
0      0.038        12.353  
9


Map:   0%|          | 0/46 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.003        0.355        0.141        0.333           0.332   

   eval_bleu  eval_gen_len  
0      0.052        11.217  
10


Map:   0%|          | 0/48 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       4.13        0.284        0.083        0.245           0.244   

   eval_bleu  eval_gen_len  
0      0.026        11.667  
11


Map:   0%|          | 0/37 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.972        0.355        0.113        0.321           0.322   

   eval_bleu  eval_gen_len  
0        0.0        11.514  
12


Map:   0%|          | 0/47 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.884        0.376        0.108        0.338           0.339   

   eval_bleu  eval_gen_len  
0        0.0        10.979  
13


Map:   0%|          | 0/48 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       4.06        0.356        0.134        0.319           0.319   

   eval_bleu  eval_gen_len  
0      0.058        11.896  
14


Map:   0%|          | 0/40 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.004        0.398        0.133        0.356           0.355   

   eval_bleu  eval_gen_len  
0      0.058          12.1  
15


Map:   0%|          | 0/35 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       3.82        0.344        0.121        0.288           0.287   

   eval_bleu  eval_gen_len  
0      0.036        13.286  
16


Map:   0%|          | 0/42 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.045        0.348        0.142        0.315           0.315   

   eval_bleu  eval_gen_len  
0      0.077        12.643  
17


Map:   0%|          | 0/43 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.159        0.294        0.082        0.252           0.253   

   eval_bleu  eval_gen_len  
0      0.045        11.302  
18


Map:   0%|          | 0/39 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.232        0.306        0.081        0.264           0.263   

   eval_bleu  eval_gen_len  
0      0.034        12.308  
19


Map:   0%|          | 0/42 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       4.18        0.306        0.091        0.265           0.268   

   eval_bleu  eval_gen_len  
0      0.032        10.881  
20


Map:   0%|          | 0/49 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.104        0.353        0.116        0.314           0.316   

   eval_bleu  eval_gen_len  
0       0.05          11.0  
21


Map:   0%|          | 0/44 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.232        0.313        0.076        0.266           0.267   

   eval_bleu  eval_gen_len  
0      0.035        12.909  
22


Map:   0%|          | 0/44 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.156        0.348        0.099        0.301           0.302   

   eval_bleu  eval_gen_len  
0       0.06        11.341  
23


Map:   0%|          | 0/45 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       4.17        0.305        0.093        0.277           0.279   

   eval_bleu  eval_gen_len  
0      0.039        11.267  
24


Map:   0%|          | 0/43 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.812        0.358        0.138        0.322           0.323   

   eval_bleu  eval_gen_len  
0      0.069        11.977  
25


Map:   0%|          | 0/41 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.541        0.367        0.094        0.309           0.309   

   eval_bleu  eval_gen_len  
0      0.042         12.61  
26


Map:   0%|          | 0/53 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.094         0.33        0.095          0.3             0.3   

   eval_bleu  eval_gen_len  
0      0.053        12.151  
27


Map:   0%|          | 0/39 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.351         0.29        0.079        0.241           0.241   

   eval_bleu  eval_gen_len  
0        0.0        11.795  
28


Map:   0%|          | 0/41 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      3.876        0.367        0.133        0.344           0.345   

   eval_bleu  eval_gen_len  
0      0.084        11.976  
29


Map:   0%|          | 0/40 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.025        0.305        0.121        0.267           0.268   

   eval_bleu  eval_gen_len  
0      0.046          12.6  
30


Map:   0%|          | 0/38 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       4.23        0.306        0.108        0.281            0.28   

   eval_bleu  eval_gen_len  
0      0.039        13.079  
31


Map:   0%|          | 0/42 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.101        0.339        0.134         0.32           0.318   

   eval_bleu  eval_gen_len  
0      0.075          12.0  
32


Map:   0%|          | 0/42 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0      4.162        0.277        0.074        0.247           0.248   

   eval_bleu  eval_gen_len  
0       0.03        12.071  
33


Map:   0%|          | 0/19 [00:00<?, ? examples/s]



   eval_loss  eval_rouge1  eval_rouge2  eval_rougeL  eval_rougeLsum  \
0       4.25        0.367        0.096         0.31           0.316   

   eval_bleu  eval_gen_len  
0      0.037        11.632  


In [None]:
# NOTE(review): scratch cell that only prints 1; candidate for removal.
print(1)

In [None]:
# Display the per-batch ROUGE-1 values collected by the evaluation loop above.
rouge_1