In [24]:
import argparse
from transformers import AutoTokenizer, GPT2Model
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset, DatasetDict, Dataset
import torch
from transformers import AutoModelForCausalLM
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.nn import CrossEntropyLoss
import os
import re
import json
from sklearn.model_selection import train_test_split

In [4]:
class Args:
    """Hard-coded stand-in for the command-line arguments of this notebook.

    Attributes (all set in ``__init__``):
        train_data: list of text-file paths intended as training input.
        test_data: test-file spec; the empty string means "no test data".
        model_name: Hugging Face model / tokenizer identifier.
        output_dir: directory for checkpoints and logs, derived from ``model_name``.
    """

    def __init__(self):
        chosen_model = "gpt2-xl"
        corpus_files = [
            "../extracted_text/kumar_and_clark/kumar_and_clark_top_3.txt",
            "../extracted_text/Harrison/Harrison_top_3.txt",
        ]

        self.train_data = corpus_files
        self.test_data = ""
        self.model_name = chosen_model
        # Note the trailing slash: some callers append a file name directly.
        self.output_dir = f"../model/qa_models/{chosen_model}-medicine/"
    
def str_or_list(val):
    """Parse a command-line value that is either one file name or a bracketed list.

    Args:
        val: a string such as ``"a.txt"`` or ``"[a.txt, b.txt]"``.

    Returns:
        A list of strings. For the bracketed form each item has its surrounding
        whitespace removed and empty items are dropped (so ``"[]"`` gives ``[]``);
        any other value is returned unchanged as a one-element list.
    """
    stripped = val.strip()
    if stripped.startswith("["):
        # Split on commas, then trim each piece so "[a, b]" and "[a,b]" agree.
        pieces = (piece.strip() for piece in stripped.strip("[]").split(","))
        return [piece for piece in pieces if piece]
    return [val]

# Command-line interface kept for reference (disabled while running as a notebook);
# the hard-coded Args object below provides the same fields.
# parser = argparse.ArgumentParser()
# parser.add_argument("--train_data", help="Add input data files (single file name or list of files in the format : [a,b,c,...]. The files in the list will be concatenated before being used as training data)", required=True, type=str_or_list)
# parser.add_argument("--test_data", help="Add testing data files (single file name or list of files in the format : [a,b,c,...]. The files in the list will be concatenated before being used as testing data)", required=False, type=str_or_list)
# parser.add_argument("--model_name", help="Model and tokenizer name", required=True, type=str)
# parser.add_argument("--output_dir", help="Directory to save the trained models and checkpoints", required=False, type=str, default="./")
# args = parser.parse_args()

# Notebook configuration: every later cell reads its settings from this object.
args = Args()
# print("example :")
# print(dataset['train'][0])



In [5]:
# Load the annotated short-case file. The encoding is given explicitly because
# the text contains non-ASCII characters (e.g. "·") and JSON files are UTF-8;
# relying on the platform default encoding can fail or garble them.
with open("../extracted_text/short_cases_medicine/short_cases_medicine_annotated.json", 'r', encoding="utf-8") as f:
    json_data = json.load(f)

In [8]:
# Keep only the 'respiratory' section of the annotated cases and preview the
# first five entries (each entry is a dict; see the output below).
resp = json_data['respiratory']
resp[:5]

[{'med_hist': "INSTRUCTION Examine this patient's chest. Examine this patient's chest from the back. Examine this patient's chest from the front. SALIENT FEATURES History · Fever. · Pleuritic pain (made worse on coughing or deep breathing). · Cough (pneumonia, TB). · Haemoptysis (associated parenchymal involvement in bronchogenic carcinoma or TB). · Shortness of breath (large effusions, cardiac failure). · Exposure to asbestos (mesothelioma). · Nephrotic syndrome. Examination · Decreased movement on the affected side. · Tracheal deviation to the opposite side. * Stony dull note on the affected side. · Decreased vocal resonance and diminished breath sounds on the affected side. Proceed as follows: · Comment on aspiration marks. · Percuss for the upper level of effusion in the axilla. · Listen for bronchial breath sounds. · Listen for aegophony at the upper level of the effusion. · It is important to elicit any evidence of an underlying cause, such as clubbing, tar staining, lymph nodes,

In [116]:
# Build one training record per annotated case:
#   "text"   - medical history followed by the question, separated by a newline
#   "labels" - the reference answer (note the plural key name)
data = [{"text" : f"{elem['med_hist']}\n{elem['ques']}", "labels" : elem['ans']} for elem in resp]
# labels = [{"text" : elem['ans']} for elem in resp]

In [117]:
# Hold out 20% of the records for testing; random_state pins the split.
train, test = train_test_split(data, test_size=0.2, random_state=42)

# NOTE: `train` / `test` are rebound here from plain lists of dicts to
# datasets.Dataset objects.
train = Dataset.from_list(train)
test = Dataset.from_list(test)

In [119]:
# Bundle both splits under the keys "train" and "test" and display the summary.
dataset = DatasetDict({"train" : train, "test" : test})
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 129
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 33
    })
})

In [120]:
# Raw (untokenized) training split; a later cell rebinds this name to the
# tokenized split.
train_dataset = dataset['train']

In [121]:
# Windowing parameters. NOTE(review): the tokenizer calls in the visible
# `tokenize` cells have `max_length` / `stride` commented out, so these two
# values are currently not used by them.
context_length = 512
stride = 256

# Load the tokenizer matching the configured model and reuse its EOS token as
# the padding token.
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
tokenizer.pad_token = tokenizer.eos_token

In [107]:
def tokenize(element):
    """Tokenize a batch of records into prompt ids and answer ids.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        element: batch dict with a ``"text"`` column (prompt strings) and a
            ``"labels"`` column (answer strings).

    Returns:
        ``{"text_ids": [...], "label_ids": [...]}`` with exactly one token-id
        list per input record, in input order.
    """
    # Overflow chunks are deliberately NOT requested: they would add extra rows
    # for long prompts and break the one-to-one pairing between a prompt and
    # its answer.
    text_ids = tokenizer(
        element["text"],
        truncation=True,
        # max_length=context_length,
        # padding=True,
        # stride=stride
    )

    # The dataset column holding the answers is named "labels" (plural).
    label_ids = tokenizer(
        element["labels"],
        truncation=True,
        # max_length=context_length,
        # padding=True,
        # stride=stride
    )

    return {"text_ids": text_ids["input_ids"], "label_ids": label_ids["input_ids"]}

In [108]:
# Tokenize every split batch-wise. The original columns are removed so that
# only the columns returned by `tokenize` remain (see the summary further down).
tokenized_datasets = dataset.map(
    tokenize, batched=True, remove_columns=dataset["train"].column_names
)

Map:   0%|          | 0/129 [00:00<?, ? examples/s]

Map:   0%|          | 0/33 [00:00<?, ? examples/s]

In [109]:
# Convenience handles for the tokenized splits.
train_dataset = tokenized_datasets['train']
test_dataset = tokenized_datasets['test']

In [110]:
# Display the tokenized training split (columns and row count).
train_dataset

Dataset({
    features: ['text_ids', 'label_ids'],
    num_rows: 129
})

In [None]:
# if(args.test_data):
#     dataset = load_dataset('text',data_files={'train': args.train_data, 'test': args.test_data})
# else:
#     dataset = load_dataset('text',data_files={'train': args.train_data})

# print(dataset)

In [115]:
# context_length = 512
# stride = 256

# Re-create the tokenizer for the configured model (same setup as the earlier
# tokenizer cell) and reuse its EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
tokenizer.pad_token = tokenizer.eos_token

def tokenize(element):
    """Tokenize a batch of records into prompt ids and answer ids (no truncation).

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        element: batch dict with a ``"text"`` column (prompt strings) and a
            ``"labels"`` column (answer strings).

    Returns:
        ``{"text_ids": [...], "label_ids": [...]}`` with one token-id list per
        input record, in input order.
    """
    prompt_encoding = tokenizer(element["text"])
    # The dataset column holding the answers is named "labels" (plural).
    answer_encoding = tokenizer(element["labels"])

    return {
        "text_ids": prompt_encoding["input_ids"],
        "label_ids": answer_encoding["input_ids"],
    }

# Tokenize every split batch-wise, dropping the original string columns.
tokenized_datasets = dataset.map(
    tokenize, batched=True, remove_columns=dataset["train"].column_names
)
# Make indexing return torch tensors instead of Python lists.
tokenized_datasets.set_format("torch")
print(tokenized_datasets)

# NOTE(review): this is set after the tokenizer has already been used above;
# presumably it is meant to silence tokenizer parallelism warnings in later
# DataLoader use - confirm whether it should be moved before tokenization.
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

Map:   0%|          | 0/129 [00:00<?, ? examples/s]

Map:   0%|          | 0/33 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text_ids', 'label_ids'],
        num_rows: 129
    })
    test: Dataset({
        features: ['text_ids', 'label_ids'],
        num_rows: 33
    })
})


In [124]:
# Number of prompt tokens in the first training example.
len(tokenized_datasets['train'][0]['text_ids'])

313

In [122]:
def causallm_loss(logits, labels):
    """Mean cross-entropy between per-position logits and integer labels.

    Args:
        logits: tensor whose last dimension indexes the classes; all leading
            dimensions are flattened into one.
        labels: integer tensor with one class index per flattened logit row.

    Returns:
        Scalar tensor holding the mean loss over all positions.

    Note:
        No causal shift is applied here; callers must pass labels that are
        already aligned with the logits.
    """
    num_classes = logits.size(-1)
    criterion = CrossEntropyLoss(reduction='mean')
    # Flatten to (positions, classes) vs. (positions,) before scoring.
    return criterion(logits.view(-1, num_classes), labels.view(-1))

In [None]:
# Accelerate training loop

def build_causal_lm_batch(text_ids, label_ids, max_length=None):
    """Join prompt and answer tokens into one causal-LM training example.

    Args:
        text_ids: integer tensor of shape (batch, prompt_len) with prompt tokens.
        label_ids: integer tensor of shape (batch, answer_len) with answer tokens.
        max_length: optional cap on the joined length; when exceeded, the
            left-most tokens are dropped so that the answer is preserved.

    Returns:
        ``(input_ids, targets)`` where ``input_ids`` is the concatenation along
        dim 1 and ``targets`` is ``input_ids`` shifted left by one position
        (shape ``(batch, seq_len - 1)``) with every prompt position set to
        -100, the default ``ignore_index`` of ``CrossEntropyLoss``.
    """
    prompt_len = text_ids.shape[1]
    input_ids = torch.cat((text_ids, label_ids), dim=1)

    if max_length is not None and input_ids.shape[1] > max_length:
        dropped = input_ids.shape[1] - max_length
        input_ids = input_ids[:, dropped:]
        prompt_len = max(prompt_len - dropped, 0)

    # targets[:, j] is the token at position j + 1, so the first
    # (prompt_len - 1) targets are still prompt tokens and must be ignored.
    targets = input_ids[:, 1:].clone()
    targets[:, : max(prompt_len - 1, 0)] = -100
    return input_ids, targets


def training_loop(args, dataset, mixed_precision="fp16"):
    """Fine-tune a causal language model on prompt -> answer pairs with Accelerate.

    Each training example is the prompt (``text_ids``) followed by the answer
    (``label_ids``); the loss is the next-token cross-entropy on the answer
    tokens only.

    Args:
        args: object with ``model_name``, ``output_dir`` and ``test_data``.
        dataset: mapping with a ``'train'`` split (and ``'test'`` when
            ``args.test_data`` is set) whose rows provide ``text_ids`` and
            ``label_ids`` tensors.
        mixed_precision: value forwarded to ``Accelerator``.

    Side effects:
        Appends per-epoch loss / perplexity to ``resp_train_logs.txt`` in
        ``args.output_dir``, writes one checkpoint per epoch there, and writes
        a summary to ``../model/trained_models/logs.txt``.
    """
    import math  # local import: only the perplexity computation needs it

    # model_name = "bloom-1b1"
    model_name = args.model_name

    accelerator = Accelerator(mixed_precision = mixed_precision)
    accelerator.print("accelerator initialised")

    set_seed(42)
    accelerator.print("seed set")
    model = AutoModelForCausalLM.from_pretrained(f"{model_name}")
    accelerator.print("model loaded")
    # Longest sequence the model accepts; read before `prepare` may wrap the model.
    max_positions = getattr(model.config, "max_position_embeddings", None)

    # HYPERPARAMETERS

    num_epochs = 30
    lr = 1e-5
    checkpoint = True
    load_checkpoint = False
    resume_epoch = 0  # epoch index of the checkpoint read when load_checkpoint is True

    if load_checkpoint:
        # Load into the bare model (before `prepare`) so state-dict keys match.
        model.load_state_dict(
            torch.load(f'{args.output_dir}/{model_name}_medicine_epoch{resume_epoch}.pth', map_location="cpu")
        )

    train_dataloader = DataLoader(dataset['train'], shuffle=False, batch_size=1)
    optimizer = AdamW(model.parameters(), lr=lr)

    train_dataloader, model, optimizer = accelerator.prepare(
        train_dataloader, model, optimizer
    )

    if(args.test_data):
        # Prepared for evaluation; not consumed by the training loop below.
        test_dataloader = DataLoader(dataset['test'], batch_size=1)
        test_dataloader = accelerator.prepare(test_dataloader)

    accelerator.print("dataloaders initialised")

    # The schedule must span every optimizer step that will actually run:
    # warm up over the first fifth of the epochs, then decay linearly to zero
    # at the end of the last epoch.
    steps_per_epoch = len(train_dataloader)
    training_steps = num_epochs * steps_per_epoch
    warm_up_steps = num_epochs//5 * steps_per_epoch
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer, num_warmup_steps=warm_up_steps, num_training_steps=training_steps
    )
    accelerator.print("scheduler initialised")

    summary_path = "../model/trained_models/logs.txt"
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(os.path.dirname(summary_path), exist_ok=True)

    progress_bar = tqdm(range(training_steps))
    epoch_losses = []
    best = 1  # 1-based index into epoch_losses of the lowest loss so far

    loss_fct = CrossEntropyLoss(reduction='mean')  # ignores targets equal to -100
    model.train()
    accelerator.print("training started")
    for epoch in range(num_epochs):
        step_losses = []
        for batch in train_dataloader:
            input_ids, targets = build_causal_lm_batch(
                batch['text_ids'], batch['label_ids'], max_length=max_positions
            )
            if not (targets != -100).any():
                # No answer token to learn from: the mean loss would be NaN.
                progress_bar.update(1)
                continue

            logits = model(input_ids).logits
            # Position i predicts token i + 1, so drop the last position.
            shift_logits = logits[:, :-1, :]
            preds = shift_logits.reshape(-1, shift_logits.size(-1))

            loss = loss_fct(preds, targets.reshape(-1))
            accelerator.backward(loss)
            step_losses.append(loss.item())

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

        epoch_loss = sum(step_losses)/len(step_losses) if step_losses else float("nan")
        epoch_losses.append(epoch_loss)
        try:
            ppl = math.exp(epoch_loss)
        except OverflowError:
            ppl = float("inf")

        if accelerator.is_main_process:
            with open(os.path.join(args.output_dir, "resp_train_logs.txt"),"a") as f:
                f.write(f"loss : {epoch_loss:.3f} , perplexity : + {ppl:.3f} \n")

        if epoch_loss < epoch_losses[best-1]:
            best = len(epoch_losses)

        if(checkpoint):
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                # Unwrap so the saved keys match a plain, un-prepared model.
                torch.save(
                    accelerator.unwrap_model(model).state_dict(),
                    f'{args.output_dir}/{model_name}_medicine_epoch{epoch}.pth',
                )

    accelerator.print("training ended")
    accelerator.print(epoch_losses)
    if accelerator.is_main_process:
        with open(summary_path,"w")as f:
            f.write(f"best = {best}\n" + str(epoch_losses))
    # torch.save(model.state_dict(),f'../model/trained_models/{model_name}_harrison_respiratory.pth')
    accelerator.print("best saved")
    


In [None]:
# Launch fine-tuning on the tokenized splits (default mixed_precision="fp16").
training_loop(args, tokenized_datasets)

In [130]:
# Stand-alone DataLoader over the tokenized training split (one example per
# batch, fixed order), used by the cells below to inspect a single batch.
train_dataloader = DataLoader(tokenized_datasets['train'], shuffle=False, batch_size=1)

In [128]:
# Display the raw (untokenized) splits for reference.
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 129
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 33
    })
})

In [131]:
# Pull the first batch from the DataLoader for inspection.
it = iter(train_dataloader)
sample = next(it)

In [133]:
# NOTE(review): if both values are torch tensors, `+` adds them element-wise
# (with broadcasting) rather than joining the sequences; `torch.cat(..., dim=1)`
# is the call that concatenates prompt and answer ids - confirm the intent here.
sample['text_ids'] + sample['label_ids']

tensor([[ 3268, 46126,  2849,  1475,  9862,   428,  5827,   338,  7721,    13,
         42475, 28495, 18630, 47471,  7443, 14128, 41951,   290,  1755, 39387,
            13, 14128,  4434,    64,   786,    11, 18787,    11,   281,   382,
         36072,    13, 14128, 14331,  2994,    13, 14128,   327,   619,   351,
           599,   315,   388,    13, 50105,  2312,     6,  8071,  2334,  4327,
           284,   423,  5895,   286,  2219,  7721, 10040,   543,   389,   407,
          2005,   290, 16577,    13,  1318,   389,  1811,   285,  2812,   329,
           428,    11,   884,   355,  3339,  1523,  6546,  3101,    11, 41899,
           330, 38385,   290, 29631,   261, 42505,    11,  3917,  7375,  2885,
            11,  3917,  7721, 10280,    11,   458,  2381,   496,   393,   872,
           918,   291, 16384, 19813,    13,   383,  1708,  2148,   617,  6096,
            25, 35550,   352,   383,  4540,   373,  1965,   284, 10716,   262,
          7721,   422,   262,  2166,    11,   355,  

In [None]:
# NOTE(review): '' is not one of the dataset columns shown above ('text_ids',
# 'label_ids'), so this lookup presumably raises KeyError - looks like an
# unfinished scratch cell.
sample['']

In [135]:
# Scratch check: torch.cat along dim=1 joins two (1, n) tensors end to end.
torch.cat((torch.tensor([[1,2,3]]), torch.tensor([[4,5,6]])), dim = 1)

tensor([[1, 2, 3, 4, 5, 6]])

In [137]:
# Scratch check: shape[1] is the length of the second dimension of a (1, n) tensor.
torch.tensor([[1,2,3]]).shape[1]

3