In [None]:
import torch
import json
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from tqdm import tqdm


In [None]:
# Load dataset.
# JSON text is UTF-8 (RFC 8259); pin the encoding instead of relying on the
# platform default (e.g. cp1252 on Windows), which can garble or reject non-ASCII text.
with open('data.json', 'r', encoding='utf-8') as file:
    synthetic_data = json.load(file)

synthetic_sentences = [item['text'] for item in synthetic_data]
# Only the first label's 'action' is used; items whose 'labels' key is missing,
# None or empty fall back to the 'no_action' class.
synthetic_labels = [
    item['labels'][0]['action'] if item.get('labels') else 'no_action'
    for item in synthetic_data
]


In [None]:

# Map each action string to an integer class id; the fitted encoder is kept
# around so predicted ids can be decoded back into action names later.
label_encoder = LabelEncoder()
synthetic_labels = label_encoder.fit(synthetic_labels).transform(synthetic_labels)


In [None]:

# WordPiece-tokenize every sentence with the bert-base-uncased vocabulary,
# padding to the longest sentence (truncating at the model's maximum length)
# and returning PyTorch tensors.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenized_synthetic_inputs = tokenizer(
    synthetic_sentences,
    return_tensors='pt',
    truncation=True,
    padding=True,
)


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:

# Create the PyTorch dataset and an 80/20 train/validation split.
synthetic_input_ids = tokenized_synthetic_inputs['input_ids']
synthetic_attention_mask = tokenized_synthetic_inputs['attention_mask']

# Build the label tensor under a new name instead of overwriting `synthetic_labels`.
# The cell then stays idempotent on re-run, and `synthetic_labels` remains the
# encoder's numpy output. Torch tensors hash by identity, so e.g. `set(tensor)`
# counts samples rather than distinct classes.
# Cross-entropy expects int64 class ids, hence the explicit long dtype.
synthetic_label_tensor = torch.as_tensor(synthetic_labels, dtype=torch.long)

synthetic_dataset = TensorDataset(synthetic_input_ids, synthetic_attention_mask, synthetic_label_tensor)
synthetic_train_dataset, synthetic_val_dataset = train_test_split(
    synthetic_dataset, test_size=0.2, random_state=42
)
assert len(synthetic_train_dataset) + len(synthetic_val_dataset) == len(synthetic_dataset)


In [None]:

# Wrap both splits in DataLoaders. Only the training loader shuffles; its
# generator is seeded (matching the split's random_state=42) so the batch
# order, and therefore the training run, is reproducible.
BATCH_SIZE = 8  # adjust batch size here
synthetic_train_dataloader = DataLoader(
    synthetic_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
synthetic_val_dataloader = DataLoader(synthetic_val_dataset, batch_size=BATCH_SIZE, shuffle=False)



In [None]:

# Initialize BERT for sequence classification with one output per encoded action.
# Size the head from the fitted encoder. `synthetic_labels` may be a torch tensor
# at this point, and tensors hash by identity, so `len(set(tensor))` would count
# samples rather than classes and build an oversized head.
torch.manual_seed(42)  # the classification head is newly (randomly) initialized
id2label = {i: str(name) for i, name in enumerate(label_encoder.classes_)}
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(id2label),
    id2label=id2label,
    label2id={name: i for i, name in id2label.items()},
)


In [None]:

# AdamW at a 7e-5 peak learning rate (adjust as needed). The rate decays
# linearly to zero over 10 epochs' worth of optimizer steps, with no warm-up.
steps_per_epoch = len(synthetic_train_dataloader)
total_steps = steps_per_epoch * 10  # 10 = number of epochs
optimizer = torch.optim.AdamW(params=model.parameters(), lr=7e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)


In [None]:

# Training loop: fine-tune the model, evaluating on the validation split after every epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# The LR schedule was built for `total_steps` optimizer steps; derive the epoch
# count from it so the loop and the scheduler cannot drift apart.
num_epochs = total_steps // len(synthetic_train_dataloader)

# Report every encoded class by name, not only the ids that happen to occur in
# the small validation split. Classes that are never predicted score 0 instead
# of emitting an UndefinedMetricWarning for each report.
report_labels = list(range(len(label_encoder.classes_)))
report_names = [str(name) for name in label_encoder.classes_]


def _batch_to_inputs(batch):
    """Move an (input_ids, attention_mask, labels) batch to `device` as model kwargs."""
    input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}


for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch in tqdm(synthetic_train_dataloader, desc=f'Epoch {epoch + 1}'):
        inputs = _batch_to_inputs(batch)
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        # Clip gradients to unit norm to keep fine-tuning updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    average_loss = total_loss / len(synthetic_train_dataloader)
    print(f'Training Loss: {average_loss}')

    # Validation loop
    model.eval()
    val_predictions, val_labels = [], []
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(synthetic_val_dataloader, desc=f'Validation Epoch {epoch + 1}'):
            inputs = _batch_to_inputs(batch)
            outputs = model(**inputs)
            # Labels are passed in, so the model also returns the validation loss.
            val_loss += outputs.loss.item()
            val_predictions.extend(torch.argmax(outputs.logits, dim=1).cpu().tolist())
            val_labels.extend(batch[2].tolist())

    print(f'Validation Loss after Epoch {epoch + 1}: {val_loss / max(len(synthetic_val_dataloader), 1)}')
    val_accuracy = accuracy_score(val_labels, val_predictions)
    print(f'Validation Accuracy after Epoch {epoch + 1}: {val_accuracy}')
    print(classification_report(
        val_labels,
        val_predictions,
        labels=report_labels,
        target_names=report_names,
        zero_division=0,
    ))

Epoch 1: 100%|██████████| 9/9 [00:33<00:00,  3.74s/it]


Training Loss: 4.074355125427246


Validation Epoch 1: 100%|██████████| 3/3 [00:02<00:00,  1.39it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 1: 0.21052631578947367
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.00      0.00      0.00         7
           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         2
           4       0.00      0.00      0.00         1
           5       0.19      1.00      0.32         3
           6       0.00      0.00      0.00         1
           7       0.00      0.00      0.00         1
           9       1.00      1.00      1.00         1
          11       0.00      0.00      0.00         1

    accuracy                           0.21        19
   macro avg       0.12      0.20      0.13        19
weighted avg       0.08      0.21      0.10        19



Epoch 2: 100%|██████████| 9/9 [00:24<00:00,  2.75s/it]


Training Loss: 3.203084389368693


Validation Epoch 2: 100%|██████████| 3/3 [00:01<00:00,  1.91it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 2: 0.2631578947368421
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.00      0.00      0.00         7
           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         2
           4       0.00      0.00      0.00         1
           5       0.19      1.00      0.32         3
           6       1.00      1.00      1.00         1
           7       0.00      0.00      0.00         1
           9       1.00      1.00      1.00         1
          11       0.00      0.00      0.00         1

    accuracy                           0.26        19
   macro avg       0.22      0.30      0.23        19
weighted avg       0.13      0.26      0.16        19



Epoch 3: 100%|██████████| 9/9 [00:33<00:00,  3.67s/it]


Training Loss: 2.635454840130276


Validation Epoch 3: 100%|██████████| 3/3 [00:01<00:00,  1.87it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 3: 0.631578947368421
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.88      1.00      0.93         7
           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         2
           4       0.00      0.00      0.00         1
           5       0.43      1.00      0.60         3
           6       0.00      0.00      0.00         1
           7       1.00      1.00      1.00         1
           9       0.33      1.00      0.50         1
          11       0.00      0.00      0.00         1

    accuracy                           0.63        19
   macro avg       0.26      0.40      0.30        19
weighted avg       0.46      0.63      0.52        19



Epoch 4: 100%|██████████| 9/9 [00:25<00:00,  2.82s/it]


Training Loss: 2.1152016984091864


Validation Epoch 4: 100%|██████████| 3/3 [00:02<00:00,  1.39it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 4: 0.7368421052631579
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.78      1.00      0.88         7
           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         2
           4       1.00      1.00      1.00         1
           5       0.75      1.00      0.86         3
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         1
           9       0.33      1.00      0.50         1
          11       0.00      0.00      0.00         1

    accuracy                           0.74        19
   macro avg       0.49      0.60      0.52        19
weighted avg       0.58      0.74      0.64        19



Epoch 5: 100%|██████████| 9/9 [00:25<00:00,  2.85s/it]


Training Loss: 1.6716329389148288


Validation Epoch 5: 100%|██████████| 3/3 [00:02<00:00,  1.41it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 5: 0.7368421052631579
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.78      1.00      0.88         7
           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         2
           4       1.00      1.00      1.00         1
           5       1.00      1.00      1.00         3
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         1
           9       0.25      1.00      0.40         1
          11       0.00      0.00      0.00         1

    accuracy                           0.74        19
   macro avg       0.50      0.60      0.53        19
weighted avg       0.62      0.74      0.66        19



Epoch 6: 100%|██████████| 9/9 [00:25<00:00,  2.82s/it]


Training Loss: 1.3593677414788141


Validation Epoch 6: 100%|██████████| 3/3 [00:01<00:00,  2.24it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 6: 0.8421052631578947
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.78      1.00      0.88         7
           2       0.00      0.00      0.00         1
           3       1.00      1.00      1.00         2
           4       1.00      1.00      1.00         1
           5       1.00      1.00      1.00         3
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         1
           9       0.50      1.00      0.67         1
          11       0.00      0.00      0.00         1

    accuracy                           0.84        19
   macro avg       0.63      0.70      0.65        19
weighted avg       0.73      0.84      0.78        19



Epoch 7: 100%|██████████| 9/9 [00:27<00:00,  3.08s/it]


Training Loss: 1.1551348898145888


Validation Epoch 7: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 7: 0.8421052631578947
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.78      1.00      0.88         7
           2       0.00      0.00      0.00         1
           3       1.00      1.00      1.00         2
           4       0.50      1.00      0.67         1
           5       1.00      1.00      1.00         3
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         1
           9       1.00      1.00      1.00         1
          11       0.00      0.00      0.00         1

    accuracy                           0.84        19
   macro avg       0.63      0.70      0.65        19
weighted avg       0.73      0.84      0.78        19



Epoch 8: 100%|██████████| 9/9 [00:26<00:00,  2.97s/it]


Training Loss: 1.0334804985258315


Validation Epoch 8: 100%|██████████| 3/3 [00:01<00:00,  2.24it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 8: 0.8421052631578947
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.78      1.00      0.88         7
           2       0.00      0.00      0.00         1
           3       1.00      1.00      1.00         2
           4       0.50      1.00      0.67         1
           5       1.00      1.00      1.00         3
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         1
           9       1.00      1.00      1.00         1
          11       0.00      0.00      0.00         1

    accuracy                           0.84        19
   macro avg       0.63      0.70      0.65        19
weighted avg       0.73      0.84      0.78        19



Epoch 9: 100%|██████████| 9/9 [00:26<00:00,  2.93s/it]


Training Loss: 0.9351864059766134


Validation Epoch 9: 100%|██████████| 3/3 [00:01<00:00,  2.25it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy after Epoch 9: 0.8421052631578947
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.78      1.00      0.88         7
           2       0.00      0.00      0.00         1
           3       1.00      1.00      1.00         2
           4       0.50      1.00      0.67         1
           5       1.00      1.00      1.00         3
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         1
           9       1.00      1.00      1.00         1
          11       0.00      0.00      0.00         1

    accuracy                           0.84        19
   macro avg       0.63      0.70      0.65        19
weighted avg       0.73      0.84      0.78        19



Epoch 10: 100%|██████████| 9/9 [00:25<00:00,  2.87s/it]


Training Loss: 0.8900034427642822


Validation Epoch 10: 100%|██████████| 3/3 [00:01<00:00,  2.26it/s]

Validation Accuracy after Epoch 10: 0.8421052631578947
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.78      1.00      0.88         7
           2       0.00      0.00      0.00         1
           3       1.00      1.00      1.00         2
           4       0.50      1.00      0.67         1
           5       1.00      1.00      1.00         3
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         1
           9       1.00      1.00      1.00         1
          11       0.00      0.00      0.00         1

    accuracy                           0.84        19
   macro avg       0.63      0.70      0.65        19
weighted avg       0.73      0.84      0.78        19




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
def test_sample_input(model, tokenizer, label_encoder, sample_input):
    tokenized_input = tokenizer(sample_input, padding=True, truncation=True, return_tensors='pt')

    # Forward pass through the model
    with torch.no_grad():
        model.eval()
        inputs = {'input_ids': tokenized_input['input_ids'],
                  'attention_mask': tokenized_input['attention_mask']}
        outputs = model(**inputs)

    # Get predicted label
    logits = outputs.logits
    predicted_label = torch.argmax(logits, dim=1).item()

    # Decode the predicted label using the provided label encoder
    decoded_label = label_encoder.inverse_transform([predicted_label])[0]

    print(f"Sample Input: {sample_input}")
    print(f"Predicted Action: {decoded_label}")
    print("\n")

# Smoke-test the fine-tuned classifier on hand-written sentences; add more
# strings to the list to try additional inputs.
for sample_input in ["Health is wealth"]:
    test_sample_input(model, tokenizer, label_encoder, sample_input)


Sample Input: Health is wealth
Predicted Action: search


