In [1]:
from datapipe import *

  from .autonotebook import tqdm as notebook_tqdm
Downloading readme: 100%|██████████| 381/381 [00:00<00:00, 1.63MB/s]
Downloading data: 100%|██████████| 107k/107k [00:00<00:00, 175kB/s]
Downloading data files: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 397.00it/s]
Generating train split: 400 examples [00:00, 11942.44 examples/s]
vocab.json: 100%|██████████| 1.04M/1.04M [00:00<00:00, 1.63MB/s]
merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 6.01MB/s]
tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 2.37MB/s]
config.json: 100%|██████████| 762/762 [00:00<00:00, 3.27MB/s]

Max length set to:  512





In [5]:
gpu = 0
batch_size = 8
num_epochs = 10
model_name = 'distilgpt2'

In [3]:
model = GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
criterion = nn.CrossEntropyLoss(ignore_index = tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
tokenizer.pad_token = tokenizer.eos_token

model.safetensors: 100%|██████████| 353M/353M [01:14<00:00, 4.74MB/s] 
generation_config.json: 100%|██████████| 124/124 [00:00<00:00, 72.8kB/s]


In [4]:
# Init a results dataframe
results = pd.DataFrame(columns=['epoch', 'transformer', 'batch_size', 'gpu',
                                'training_loss', 'validation_loss', 'epoch_duration_sec'])

In [6]:
# The training loop
for epoch in range(num_epochs):
    start_time = time.time()  # Start the timer for the epoch

    # Training
    ## This line tells the model we're in 'learning mode'
    model.train()
    epoch_training_loss = 0
    train_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs} Batch Size: {batch_size}, Transformer: {model_name}")
    for batch in train_iterator:
        optimizer.zero_grad()
        inputs = batch['input_ids'].squeeze(1).to(device)
        targets = inputs.clone()
        outputs = model(input_ids=inputs, labels=targets)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_iterator.set_postfix({'Training Loss': loss.item()})
        epoch_training_loss += loss.item()
    avg_epoch_training_loss = epoch_training_loss / len(train_iterator)

    # Validation
    ## This line below tells the model to 'stop learning'
    model.eval()
    epoch_validation_loss = 0
    total_loss = 0
    valid_iterator = tqdm(valid_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}")
    with torch.no_grad():
        for batch in valid_iterator:
            inputs = batch['input_ids'].squeeze(1).to(device)
            targets = inputs.clone()
            outputs = model(input_ids=inputs, labels=targets)
            loss = outputs.loss
            total_loss += loss
            valid_iterator.set_postfix({'Validation Loss': loss.item()})
            epoch_validation_loss += loss.item()

    avg_epoch_validation_loss = epoch_validation_loss / len(valid_loader)

    end_time = time.time()  # End the timer for the epoch
    epoch_duration_sec = end_time - start_time  # Calculate the duration in seconds

    new_row = {'transformer': model_name,
               'batch_size': batch_size,
               'gpu': gpu,
               'epoch': epoch+1,
               'training_loss': avg_epoch_training_loss,
               'validation_loss': avg_epoch_validation_loss,
               'epoch_duration_sec': epoch_duration_sec}  # Add epoch_duration to the dataframe

    results.loc[len(results)] = new_row
    print(f"Epoch: {epoch+1}, Validation Loss: {total_loss/len(valid_loader)}")

Training Epoch 1/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:17<00:00,  2.34it/s, Training Loss=10.3] 
Validation Epoch 1/10: 100%|██████████| 10/10 [00:00<00:00, 12.47it/s, Validation Loss=10] 


Epoch: 1, Validation Loss: 9.955349922180176


Training Epoch 2/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  3.06it/s, Training Loss=3.87]
Validation Epoch 2/10: 100%|██████████| 10/10 [00:00<00:00, 12.38it/s, Validation Loss=4.99]


Epoch: 2, Validation Loss: 4.388497352600098


Training Epoch 3/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  3.06it/s, Training Loss=9.45]
Validation Epoch 3/10: 100%|██████████| 10/10 [00:00<00:00, 12.55it/s, Validation Loss=9.4]


Epoch: 3, Validation Loss: 8.26036262512207


Training Epoch 4/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  2.97it/s, Training Loss=7.46]
Validation Epoch 4/10: 100%|██████████| 10/10 [00:00<00:00, 12.42it/s, Validation Loss=7.86]


Epoch: 4, Validation Loss: 6.821663856506348


Training Epoch 5/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  3.01it/s, Training Loss=6.29]
Validation Epoch 5/10: 100%|██████████| 10/10 [00:00<00:00, 12.40it/s, Validation Loss=10.2]


Epoch: 5, Validation Loss: 8.952461242675781


Training Epoch 6/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  3.04it/s, Training Loss=11.5]
Validation Epoch 6/10: 100%|██████████| 10/10 [00:00<00:00, 11.01it/s, Validation Loss=10.9]


Epoch: 6, Validation Loss: 9.536514282226562


Training Epoch 7/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  2.94it/s, Training Loss=9.18]
Validation Epoch 7/10: 100%|██████████| 10/10 [00:00<00:00, 10.52it/s, Validation Loss=11.8]


Epoch: 7, Validation Loss: 10.26335334777832


Training Epoch 8/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  2.86it/s, Training Loss=12.9]
Validation Epoch 8/10: 100%|██████████| 10/10 [00:00<00:00, 11.04it/s, Validation Loss=12.1]


Epoch: 8, Validation Loss: 10.555124282836914


Training Epoch 9/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  2.91it/s, Training Loss=7.19]
Validation Epoch 9/10: 100%|██████████| 10/10 [00:00<00:00, 10.93it/s, Validation Loss=11.8]


Epoch: 9, Validation Loss: 10.212566375732422


Training Epoch 10/10 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [00:13<00:00,  2.96it/s, Training Loss=8.84]
Validation Epoch 10/10: 100%|██████████| 10/10 [00:00<00:00, 10.69it/s, Validation Loss=10.8]


Epoch: 10, Validation Loss: 9.4764404296875


In [None]:
input_str = "Kidney Failure"
input_ids = tokenizer.encode(input_str, return_tensors='pt').to(device)

output = model.generate(
                        input_ids,
                        max_length=20,
                        num_return_sequences=1,
                        do_sample=True,
                        top_k=8,
                        top_p=0.95,
                        temperature=0.5,
                        repetition_penalty=1.2
                        )

decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)

In [None]:
torch.save(model, 'SmallMedLM.pt')