In [4]:
!pip install torch transformers tqdm

  pid, fd = os.forkpty()




In [5]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import CrossEntropyLoss
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm import tqdm
import json


In [6]:
# Constants
MODEL_NAME = "gpt2"
# MODEL_NAME = "facebook/opt-350m"
PROMPT_TOKEN = "[GENERATE] [JSON] [OBJECT] [MODEL] [FORMAT] [KEY] [VALUE] [FIELD]"
MAX_LEN = 1024

# Hyperparameters
BATCH_SIZE = 1
EPOCHS = 100
GRADIENT_ACCUMULATION_STEPS = 1
GRADIENT_CLIP_NORM = 1.0
EARLY_STOPPING_PATIENCE = 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Soft Prompt Vocabulary
soft_prompt_vocab = ["[GENERATE]", "[JSON]", "[OBJECT]", "[MODEL]", "[FORMAT]", "[KEY]", "[VALUE]", "[FIELD]"]  # Define your custom vocabulary here

# Create a word2idx dictionary for the soft prompt vocabulary
soft_prompt_word2idx = {word: idx for idx, word in enumerate(soft_prompt_vocab)}

num_prompts = len([soft_prompt_word2idx[word] for word in PROMPT_TOKEN.split()])
prompt_id = torch.tensor([soft_prompt_word2idx[word] for word in PROMPT_TOKEN.split()])
prompt_id = prompt_id.to(device)


In [7]:
# Model Architecture
class GPT2WithSoftPrompt(torch.nn.Module):
    def __init__(self, model_name, num_prompts, embedding_size=768):
        super().__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
        self.soft_prompt = torch.nn.Embedding(num_prompts, embedding_size)

    def forward(self, input_ids, prompt_ids):
        prompt_embeddings = self.soft_prompt(prompt_ids)
        base_embeddings = self.gpt2.transformer.wte(input_ids)
        embeddings = torch.cat([prompt_embeddings, base_embeddings.squeeze(0)], dim=0)
        outputs = self.gpt2(inputs_embeds=embeddings)
        return outputs


In [8]:
# Data Loading and Preprocessing
def load_and_preprocess_data(file_path, num_prompts):
    file = open(file_path, "r")
    
    data = json.load(file)
    tokenized_inputs = []
    tokenized_outputs = []

    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)

    for item in data:
        # Adjust the maximum length of articles to avoid exceeding MAX_LEN
        max_length_article = MAX_LEN - num_prompts 
        output_tokens = tokenizer.encode(json.dumps(item["output"]), truncation=True, max_length=max_length_article)
        input_tokens = tokenizer.encode(item["input"], truncation=True, max_length=300)

        max_length_summary = MAX_LEN
        padded_input = input_tokens + [tokenizer.eos_token_id] * (max_length_article - len(input_tokens))
        padded_output = output_tokens + [tokenizer.eos_token_id] * (max_length_summary - len(output_tokens))

        tokenized_inputs.append(padded_input)
        tokenized_outputs.append(padded_output)

    file.close()
    
    train_limit = int(len(tokenized_inputs) * 0.7)
    val_limit = int(len(tokenized_inputs) * 0.9)

    return tokenized_inputs[:train_limit], tokenized_outputs[:train_limit], tokenized_inputs[train_limit:val_limit], tokenized_outputs[train_limit:val_limit], tokenized_inputs[val_limit:], tokenized_outputs[val_limit:]


In [9]:
# Load and preprocess the data
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenized_inputs_train, tokenized_outputs_train, tokenized_inputs_validation, tokenized_outputs_validation, tokenized_inputs_test, tokenized_outputs_test = load_and_preprocess_data("/kaggle/input/json-dataset/dataset.json", num_prompts)

# Model Initialization
model = GPT2WithSoftPrompt(MODEL_NAME, num_prompts).to(device)


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [10]:
len(tokenized_inputs_train[0])

1016

In [11]:
len(tokenized_outputs_train[0])

1024

In [12]:
def fine_tune_on_summarization(model, train_inputs, train_outputs, val_inputs, val_outputs, test_inputs, test_outputs):
    optimizer = torch.optim.Adam(model.soft_prompt.parameters())

    best_val_loss = float('inf')
    no_improvement_epochs = 0

    for epoch in range(EPOCHS):
        model.train()

        # Gradient accumulation initialization
        optimizer.zero_grad()

        accumulated_loss = 0
        loss = 0

        # Use tqdm for progress bar
        with tqdm(enumerate(zip(train_inputs, train_outputs)), total=len(train_inputs), desc=f"Epoch {epoch + 1}/{EPOCHS}", unit="batch") as progress:
            train_percentage_matched = 0
            train_percentage_matched_ct = 0

            for idx, (input, output) in progress:
                input_ids = torch.tensor(input).to(device)
                labels = torch.tensor(output).to(device)
                outputs = model(input_ids, prompt_id)

                ignore_index = tokenizer.eos_token_id
                loss += CrossEntropyLoss(ignore_index=ignore_index)(outputs.logits, labels)

                # Metrics
                set1 = set(torch.argmax(outputs.logits, dim=1).cpu().numpy())
                set2 = set(labels.cpu().numpy())

                # Calculate the intersection of sets
                intersection = set1.intersection(set2)

                # Calculate the percentage of indices in the first tensor that are also in the second tensor
                percentage = (len(intersection) / len(set1)) * 100
                train_percentage_matched += percentage
                train_percentage_matched_ct += 1

                # Backpropagate losses every GRADIENT_ACCUMULATION_STEPS or at the end of the dataset
                if (idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or idx == len(train_inputs) - 1:
                    (loss / GRADIENT_ACCUMULATION_STEPS).backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_NORM)
                    optimizer.step()
                    optimizer.zero_grad()
                    loss = 0
            
            print("Train : % Exact Match: ",train_percentage_matched/train_percentage_matched_ct)

        # Validation
        model.eval()
        total_val_loss = 0

        with torch.no_grad():
            val_percentage_matched = 0
            val_percentage_matched_ct = 0

            for input, output in tqdm(zip(val_inputs, val_outputs), total=len(val_inputs), desc="Validation", unit="batch"):
                input_ids = torch.tensor(input).to(device)
                labels = torch.tensor(output).to(device)
                outputs = model(input_ids, prompt_id)

                ignore_index = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -100
                val_loss = CrossEntropyLoss(ignore_index=ignore_index)(outputs.logits, labels)
                total_val_loss += val_loss.item()

                # Metrics
                set1 = set(torch.argmax(outputs.logits, dim=1).cpu().numpy())
                set2 = set(labels.cpu().numpy())

                # Calculate the intersection of sets
                intersection = set1.intersection(set2)

                # Calculate the percentage of indices in the first tensor that are also in the second tensor
                percentage = (len(intersection) / len(set1)) * 100
                val_percentage_matched += percentage
                val_percentage_matched_ct += 1

        print("Val : % Exact Match: ",val_percentage_matched/val_percentage_matched_ct)
        avg_val_loss = total_val_loss / len(val_inputs)
        print("Val Loss : ",avg_val_loss)

        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1
            if no_improvement_epochs >= EARLY_STOPPING_PATIENCE:
                print(f"Early stopping after {EARLY_STOPPING_PATIENCE} epochs without improvement.")
                break


    # Testing
    model.eval()
    total_test_loss = 0

    with torch.no_grad():
        test_percentage_matched = 0
        test_percentage_matched_ct = 0

        for input, output in tqdm(zip(test_inputs, test_outputs), total=len(test_inputs), desc="Test", unit="batch"):
            input_ids = torch.tensor(input).to(device)
            labels = torch.tensor(output).to(device)
            outputs = model(input_ids, prompt_id)

            ignore_index = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -100
            test_loss = CrossEntropyLoss(ignore_index=ignore_index)(outputs.logits, labels)
            total_test_loss += test_loss.item()

            # Metrics
            set1 = set(torch.argmax(outputs.logits, dim=1).cpu().numpy())
            set2 = set(labels.cpu().numpy())

            # Calculate the intersection of sets
            intersection = set1.intersection(set2)

            # Calculate the percentage of indices in the first tensor that are also in the second tensor
            percentage = (len(intersection) / len(set1)) * 100
            test_percentage_matched += percentage
            test_percentage_matched_ct += 1
        
        
        print("Test : % Exact Match: ",test_percentage_matched/test_percentage_matched_ct)
        avg_test_loss = total_test_loss / len(test_inputs)
        print("Test Loss : ",avg_test_loss)
        print("")


    return model


In [13]:
fine_tuned_model = fine_tune_on_summarization(model, tokenized_inputs_train, tokenized_outputs_train, tokenized_inputs_validation, tokenized_outputs_validation, tokenized_inputs_test, tokenized_outputs_test)


Epoch 1/100: 100%|██████████| 219/219 [01:00<00:00,  3.63batch/s]


Train : % Exact Match:  11.7906166180054


Validation: 100%|██████████| 62/62 [00:05<00:00, 10.68batch/s]


Val : % Exact Match:  18.33029597537911
Val Loss :  10.20312052388345


Epoch 2/100: 100%|██████████| 219/219 [01:04<00:00,  3.39batch/s]


Train : % Exact Match:  12.238617174067748


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.81batch/s]


Val : % Exact Match:  17.25844274293789
Val Loss :  9.510851798519012


Epoch 3/100: 100%|██████████| 219/219 [01:06<00:00,  3.28batch/s]


Train : % Exact Match:  12.47101623034656


Validation: 100%|██████████| 62/62 [00:06<00:00, 10.02batch/s]


Val : % Exact Match:  16.583403155983795
Val Loss :  9.065150753144295


Epoch 4/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  12.21869909420271


Validation: 100%|██████████| 62/62 [00:06<00:00, 10.05batch/s]


Val : % Exact Match:  14.64771787352432
Val Loss :  8.903754495805309


Epoch 5/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  13.938403251987372


Validation: 100%|██████████| 62/62 [00:06<00:00, 10.01batch/s]


Val : % Exact Match:  14.318781576846092
Val Loss :  8.707564446233935


Epoch 6/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  18.790590471694422


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.99batch/s]


Val : % Exact Match:  25.36279133053326
Val Loss :  8.507593539453321


Epoch 7/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  21.4263839596763


Validation: 100%|██████████| 62/62 [00:06<00:00, 10.00batch/s]


Val : % Exact Match:  25.492764930620712
Val Loss :  8.288107718190838


Epoch 8/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  23.933052196761416


Validation: 100%|██████████| 62/62 [00:06<00:00, 10.01batch/s]


Val : % Exact Match:  27.372295386626753
Val Loss :  8.020360592872866


Epoch 9/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  27.270510163453746


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.97batch/s]


Val : % Exact Match:  33.47479300746305
Val Loss :  7.799670411694434


Epoch 10/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  30.434922564794444


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.95batch/s]


Val : % Exact Match:  33.14604526298074
Val Loss :  7.433152298773488


Epoch 11/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  33.64147570522943


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.95batch/s]


Val : % Exact Match:  39.65439130761711
Val Loss :  7.209127272329023


Epoch 12/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  35.34206874133288


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.98batch/s]


Val : % Exact Match:  40.90342005664586
Val Loss :  7.031517736373409


Epoch 13/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  35.88338730318878


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.99batch/s]


Val : % Exact Match:  43.76472094214028
Val Loss :  6.848230392702164


Epoch 14/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  37.73788940489176


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.99batch/s]


Val : % Exact Match:  30.574264445232178
Val Loss :  6.646610060045796


Epoch 15/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  38.04727020384557


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.94batch/s]


Val : % Exact Match:  28.879722428109517
Val Loss :  6.536799384701636


Epoch 16/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  38.90227813506999


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.98batch/s]


Val : % Exact Match:  29.43666548505257
Val Loss :  6.462057744303057


Epoch 17/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  38.86832818652441


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.96batch/s]


Val : % Exact Match:  26.838630613535745
Val Loss :  6.394574619108631


Epoch 18/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  41.345839714992316


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.97batch/s]


Val : % Exact Match:  25.13507326007324
Val Loss :  6.359844053945234


Epoch 19/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  41.54452543559373


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.94batch/s]


Val : % Exact Match:  38.71314288178616
Val Loss :  6.327132878764983


Epoch 20/100: 100%|██████████| 219/219 [01:06<00:00,  3.28batch/s]


Train : % Exact Match:  43.06228028168101


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.93batch/s]


Val : % Exact Match:  38.6166253101737
Val Loss :  6.300835094144268


Epoch 21/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  42.72557237557872


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.97batch/s]


Val : % Exact Match:  37.798062152900876
Val Loss :  6.269700127263223


Epoch 22/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  42.83895863511198


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.96batch/s]


Val : % Exact Match:  36.849521446295654
Val Loss :  6.237122235759612


Epoch 23/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  42.98330573559592


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.96batch/s]


Val : % Exact Match:  37.78469514356612
Val Loss :  6.210632908728815


Epoch 24/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  41.85096150854874


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.98batch/s]


Val : % Exact Match:  39.048652959943325
Val Loss :  6.195588173404817


Epoch 25/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  42.921311360759596


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.93batch/s]


Val : % Exact Match:  38.426331732783346
Val Loss :  6.204828654566119


Epoch 26/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  42.509137650830596


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.93batch/s]


Val : % Exact Match:  38.8481545336384
Val Loss :  6.170154448478453


Epoch 27/100: 100%|██████████| 219/219 [01:06<00:00,  3.29batch/s]


Train : % Exact Match:  42.453235085344126


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.95batch/s]


Val : % Exact Match:  39.51298701298702
Val Loss :  6.159498145503383


Epoch 28/100: 100%|██████████| 219/219 [01:06<00:00,  3.28batch/s]


Train : % Exact Match:  42.9109562815933


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.91batch/s]


Val : % Exact Match:  37.726390465954054
Val Loss :  6.141156435012817


Epoch 29/100: 100%|██████████| 219/219 [01:06<00:00,  3.28batch/s]


Train : % Exact Match:  42.81106757208373


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.96batch/s]


Val : % Exact Match:  36.682409936679406
Val Loss :  6.1450788513306644


Epoch 30/100: 100%|██████████| 219/219 [01:06<00:00,  3.30batch/s]


Train : % Exact Match:  41.65031549918521


Validation: 100%|██████████| 62/62 [00:06<00:00,  9.99batch/s]


Val : % Exact Match:  36.45815087821731
Val Loss :  6.142245631064138
Early stopping after 2 epochs without improvement.


Test: 100%|██████████| 32/32 [00:03<00:00, 10.25batch/s]

Test : % Exact Match:  34.22540771116138
Test Loss :  7.36285375058651






# Saving Model

In [14]:
# Save the fine-tuned model
torch.save(fine_tuned_model.state_dict(), 'fine_tuned_model.pth')


# Loading Model

In [15]:
# Initialize a new instance of the model
model = GPT2WithSoftPrompt(MODEL_NAME, num_prompts).to(device)

# Load the saved model state_dict
model.load_state_dict(torch.load('fine_tuned_model.pth'))

# Make sure the model is in evaluation mode after loading
model.eval()

  model.load_state_dict(torch.load('fine_tuned_model.pth'))


GPT2WithSoftPrompt(
  (gpt2): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2SdpaAttention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out

# Inference

In [16]:
# Set the model to evaluation mode
model.eval()

# Input text for summarization
input_text = "Transform into JSON including 'multi_cloud_controller', 'orchestration_policies', 'migration_strategies', and 'cost_optimizations': 'CloudHarmonizer managed 2 orchestration policies, used live migration and backup-restore strategies, optimizing costs by 15%.'"

# Tokenize and encode the input text
input_ids = tokenizer.encode(input_text, truncation=True, max_length=1024)

# Convert the input_ids to a PyTorch tensor
input_ids = torch.tensor(input_ids)

# Generate a summary
with torch.no_grad():
    # Assuming single prompt
    outputs = model(input_ids.to(device), prompt_ids=prompt_id.to(device))
    pred_logits = outputs.logits
    print(pred_logits.shape)


# Get the token IDs with the highest probability for each position
predicted_token_ids = torch.argmax(pred_logits, dim=-1)

# Convert token IDs into words using the tokenizer
predicted_tokens = tokenizer.decode(predicted_token_ids.squeeze(0), skip_special_tokens=True)


torch.Size([71, 50257])


In [17]:
predicted_tokens

'{"mer": "":",",":":":": "":":":":":": "":":":":":ity": " "":":":":":": "": "":":":": " "":":":":":": "": "": "": "": "": "": "": "":": " " " "'

In [18]:
# Set the model to evaluation mode
model.eval()

# Input text for summarization
input_text = "Convert the following sentence into a JSON object with clear key-value pairs: 'I bought 2 flowers and a flower pot.'"

# Tokenize and encode the input text
input_ids = tokenizer.encode(input_text, truncation=True, max_length=1024)

# Convert the input_ids to a PyTorch tensor
input_ids = torch.tensor(input_ids)

# Generate a summary
with torch.no_grad():
    # Assuming single prompt
    outputs = model(input_ids.to(device), prompt_ids=prompt_id.to(device))
    pred_logits = outputs.logits
    print(pred_logits.shape)


# Get the token IDs with the highest probability for each position
predicted_token_ids = torch.argmax(pred_logits, dim=-1)

# Convert token IDs into words using the tokenizer
predicted_tokens = tokenizer.decode(predicted_token_ids.squeeze(0), skip_special_tokens=True)


torch.Size([34, 50257])


In [19]:
predicted_tokens

'{"mer": "":",",":":":":":":":":":":":":":":": " "":": "":":":": "":'