## Multi-task modelling

The key components of this multi-task learning implementation:

1. Model Architecture:
   
- The model uses BERT as a shared backbone
- It has two task-specific heads:
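    - Sentiment classifier (multi-class: positive/neutral/negative; `predict` accepts a `neutral_thresholds` argument, but the current implementation does not apply it and simply returns the most probable class)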
    - Topic classifier (multi-class)
- The shared layers learn common features useful for both tasks  
<br>
2. Key Components:

- TextDataset: Handles data preprocessing and tokenization
- MultitaskModel: The main model combining both tasks
- Training loop with combined loss function  

<br>
3. Loss Function:

- Uses a separate cross-entropy loss (`nn.CrossEntropyLoss`) for each task
- Sums the two losses with equal weight into a single total loss for optimization  

<br>
4. Benefits of Multi-task Learning:

- Shared representation learning
- Better generalization
- More efficient use of data
- Reduced overfitting, because training the shared backbone on both tasks acts as a regularizer

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel


In [None]:

class TextDataset(Dataset):
    """Map-style dataset pairing each text with a sentiment and a topic label.

    Texts are tokenized lazily, one at a time, when an item is requested.

    Args:
        texts: Indexable collection of input strings.
        sentiment_labels: Integer sentiment id per text
            (0: negative, 1: neutral, 2: positive).
        topic_labels: Integer topic id per text.
        tokenizer: Callable invoked with the text plus Hugging Face-style
            keyword arguments; its result must expose ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: Length each sequence is padded or truncated to.
    """

    def __init__(self, texts, sentiment_labels, topic_labels, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = texts
        self.sentiment_labels = sentiment_labels
        self.topic_labels = topic_labels

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the text at ``idx`` and return it with both labels.

        Returns:
            Dict with flattened ``input_ids`` and ``attention_mask`` tensors
            and ``sentiment_label`` / ``topic_label`` tensors of dtype long.
        """
        tokenizer_kwargs = dict(
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer(self.texts[idx], **tokenizer_kwargs)

        def as_long(value):
            return torch.tensor(value, dtype=torch.long)

        # The tokenizer returns a leading batch axis of size one; flatten it
        # away so the DataLoader can stack items into a batch itself.
        item = {}
        for key in ('input_ids', 'attention_mask'):
            item[key] = encoded[key].flatten()
        item['sentiment_label'] = as_long(self.sentiment_labels[idx])
        item['topic_label'] = as_long(self.topic_labels[idx])
        return item

class MultitaskModel(nn.Module):
    """BERT encoder shared by a sentiment head and a topic head.

    Both heads are single linear layers applied to the same dropout-regularized
    pooled encoder output.

    Args:
        num_topics: Number of topic classes (output width of the topic head).
        num_sentiments: Number of sentiment classes (output width of the
            sentiment head). Defaults to 3.
    """

    def __init__(self, num_topics, num_sentiments=3):
        super().__init__()

        # Shared backbone and regularization
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)

        # One linear head per task, both fed by the encoder's hidden size
        encoder_dim = self.bert.config.hidden_size
        self.sentiment_classifier = nn.Linear(encoder_dim, num_sentiments)
        self.topic_classifier = nn.Linear(encoder_dim, num_topics)

    def forward(self, input_ids, attention_mask):
        """Encode the batch once and score it with both heads.

        Returns:
            Tuple ``(sentiment_logits, topic_logits)`` of unnormalized scores.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        shared = self.dropout(encoder_out.pooler_output)
        return self.sentiment_classifier(shared), self.topic_classifier(shared)


In [None]:

def train_model(texts, sentiment_labels, topic_labels, num_topics, num_epochs=3):
    """Fine-tune a MultitaskModel on texts labelled for sentiment and topic.

    Uses the 'bert-base-uncased' tokenizer, AdamW with lr=2e-5, shuffled
    batches of 2, and the unweighted sum of two cross-entropy losses.
    Prints the mean per-batch loss after every epoch.

    Args:
        texts: Sequence of input strings.
        sentiment_labels: Sequence of sentiment ids, one per text
            (0: negative, 1: neutral, 2: positive).
        topic_labels: Sequence of topic ids, one per text.
        num_topics: Number of topic classes.
        num_epochs: Number of passes over the data.

    Returns:
        Tuple ``(model, tokenizer)``; the model is left in training mode.

    Raises:
        ValueError: If ``texts`` is empty or either label sequence does not
            have the same length as ``texts``.
    """
    # Fail fast, before the pretrained weights are loaded: a length mismatch
    # would otherwise only surface mid-epoch (or be silently ignored when the
    # label sequences are longer than ``texts``).
    num_examples = len(texts)
    if num_examples == 0:
        raise ValueError("texts must contain at least one example")
    if len(sentiment_labels) != num_examples or len(topic_labels) != num_examples:
        raise ValueError(
            "texts, sentiment_labels and topic_labels must have the same length "
            f"(got {num_examples}, {len(sentiment_labels)}, {len(topic_labels)})"
        )

    # Initialize tokenizer and model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = MultitaskModel(num_topics=num_topics, num_sentiments=3)

    # Create dataset and dataloader
    dataset = TextDataset(texts, sentiment_labels, topic_labels, tokenizer)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Initialize optimizer
    optimizer = optim.AdamW(model.parameters(), lr=2e-5)

    # One loss per task
    sentiment_criterion = nn.CrossEntropyLoss()
    topic_criterion = nn.CrossEntropyLoss()

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in dataloader:
            optimizer.zero_grad()

            sentiment_logits, topic_logits = model(
                batch['input_ids'], batch['attention_mask']
            )

            # Read labels straight from the batch instead of rebinding the
            # function's own ``sentiment_labels`` / ``topic_labels`` arguments.
            sentiment_loss = sentiment_criterion(sentiment_logits, batch['sentiment_label'])
            topic_loss = topic_criterion(topic_logits, batch['topic_label'])
            total_loss = sentiment_loss + topic_loss

            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

        # Report the mean over all batches; the last batch alone is a noisy
        # (and, with shuffling, arbitrary) summary of the epoch.
        epoch_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    return model, tokenizer



In [None]:
def predict(model, tokenizer, text, topic_names=None, neutral_thresholds=(0.4, 0.6),
            max_length=128):
    """
    Make sentiment and topic predictions for a single text input.

    The model is switched to evaluation mode and stays in it after the call.

    Args:
        model: Trained MultitaskModel
        tokenizer: BERT tokenizer
        text: Input text string
        topic_names: Optional list of topic names for readable output; when
            omitted the topic is reported as "Topic <index>"
        neutral_thresholds: Tuple of (lower, upper) thresholds for neutral
            sentiment. NOTE: currently not applied -- the sentiment label is
            always the most probable class. The argument is kept so existing
            callers keep working.
        max_length: Length the tokenized text is padded/truncated to; should
            match the value used during training

    Returns:
        Dictionary with the input ``text``, a ``sentiment`` entry (label,
        confidence, per-class probabilities) and a ``topic`` entry (label,
        confidence)
    """
    # Prepare the model for evaluation
    model.eval()

    # Tokenize the input text
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # Run the inputs on whatever device the model lives on, so a model that
    # was moved to an accelerator does not fail on CPU tensors.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)
        attention_mask = attention_mask.to(first_param.device)

    # Get model predictions
    with torch.no_grad():
        sentiment_logits, topic_logits = model(input_ids, attention_mask)

        # Batch of one: take the first (only) row of each distribution
        sentiment_probs = torch.softmax(sentiment_logits, dim=1)[0]
        topic_probs = torch.softmax(topic_logits, dim=1)[0]

        # Sentiment is the most probable class (neutral_thresholds is unused)
        sentiment_label = sentiment_probs.argmax().item()
        sentiment_confidence = float(sentiment_probs[sentiment_label])

        # Map numerical labels to text labels
        sentiment_map = {0: 'Negative', 1: 'Neutral', 2: 'Positive'}
        sentiment_text = sentiment_map[sentiment_label]

        topic_label = topic_probs.argmax().item()

    # Prepare results
    results = {
        'text': text,
        'sentiment': {
            'label': sentiment_text,
            'confidence': sentiment_confidence,
            'probabilities': {
                'negative': float(sentiment_probs[0]),
                'neutral': float(sentiment_probs[1]),
                'positive': float(sentiment_probs[2])
            }
        },
        'topic': {
            'label': topic_names[topic_label] if topic_names else f"Topic {topic_label}",
            'confidence': float(topic_probs[topic_label])
        }
    }

    return results


In [None]:

# Example usage.
# A real dataset would first be imported, cleaned, and transformed, and would
# normally be split into train/val/test sets before train_model is called.

if __name__ == "__main__":
    # Topic ids index into this list
    topic_names = ['Technology', 'Service', 'Product', 'Finance']

    # (text, sentiment id, topic id); sentiment ids are
    # 0: negative, 1: neutral, 2: positive
    examples = [
        ("This product is amazing and works great!", 2, 2),
        ("The service was terrible and disappointing.", 0, 1),
        ("The product works as expected.", 1, 0),
        ("The market outlook remains negative.", 0, 3),
        ("This is okay, nothing special.", 1, 2),
    ]
    texts, sentiment_labels, topic_labels = (list(column) for column in zip(*examples))

    # Train the model
    model, tokenizer = train_model(
        texts,
        sentiment_labels,
        topic_labels,
        num_topics=len(topic_names),
    )

    # Make predictions for new text
    new_text = "The new AI features in this product are mindblowing!"
    predictions = predict(model, tokenizer, new_text, topic_names, neutral_thresholds=(0.4, 0.6))
    sentiment_result = predictions['sentiment']
    topic_result = predictions['topic']

    # Print results
    print("\nPrediction Results:")
    print(f"Text: {predictions['text']}")
    print(f"Sentiment: {sentiment_result['label']} "
          f"(Confidence: {sentiment_result['confidence']:.2f})")
    print("Sentiment Probabilities:")
    for sentiment, prob in sentiment_result['probabilities'].items():
        print(f"  {sentiment.capitalize()}: {prob:.2f}")
    print(f"Topic: {topic_result['label']} "
          f"(Confidence: {topic_result['confidence']:.2f})")

Epoch 1/3, Loss: 2.0385
Epoch 2/3, Loss: 2.3136
Epoch 3/3, Loss: 2.4543

Prediction Results:
Text: The new AI features in this product are revolutionary!
Sentiment: Neutral (Confidence: 0.39)
Sentiment Probabilities:
  Negative: 0.24
  Neutral: 0.39
  Positive: 0.37
Topic: Product (Confidence: 0.38)


In [None]:
# Score several unseen sentences with the already-trained model and print a
# report for each. Uses model, tokenizer and topic_names from the example above.
new_text = [
    "The new AI features in this product are mindblowing!",
    "The markets will open higher today",
]
divider = "\n    ****  ****  ****  ****  ****  ****  ****  **** \n"

for txt in new_text:
    predictions = predict(model, tokenizer, txt, topic_names, neutral_thresholds=(0.4, 0.6))
    sentiment_result = predictions['sentiment']
    topic_result = predictions['topic']

    # Print results
    print("Prediction Results:")
    print(f"Text: {predictions['text']}\n")
    print(f"Sentiment: {sentiment_result['label']} "
          f"(Confidence: {sentiment_result['confidence']:.2f})")
    print("Sentiment Probabilities:")
    for sentiment, prob in sentiment_result['probabilities'].items():
        print(f"  {sentiment.capitalize()}: {prob:.2f}")
    print(f"\nTopic: {topic_result['label']} "
          f"(Confidence: {topic_result['confidence']:.2f})")

    print(divider)

Prediction Results:
Text: The new AI features in this product are mindblowing!

Sentiment: Positive (Confidence: 0.41)
Sentiment Probabilities:
  Negative: 0.21
  Neutral: 0.39
  Positive: 0.41

Topic: Product (Confidence: 0.35)

    ****  ****  ****  ****  ****  ****  ****  **** 

Prediction Results:
Text: The markets will open higher today

Sentiment: Neutral (Confidence: 0.41)
Sentiment Probabilities:
  Negative: 0.29
  Neutral: 0.41
  Positive: 0.30

Topic: Product (Confidence: 0.34)

    ****  ****  ****  ****  ****  ****  ****  **** 

