In [13]:
!python -m pip install -U sentence-transformers



In [14]:
import numpy as np
from typing import Dict

In [15]:
import numpy as np
from typing import Dict
from scipy.special import expit, softmax

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer

class PreDUNES(nn.Module):
    def __init__(
            self,
            twitter_embedding_model: nn.Module,
            twitter_sentiment_tokenizer,
            twitter_sentiment_model: nn.Module,
            reddit_sentiment_tokenizer,
            reddit_sentiment_model: nn.Module,
            twitter_sector_tokenizer,
            twitter_sector_model: nn.Module
        ):
        '''
        Initialize a DUNES model from a set of embeddings and sentiment models.
        Args:
            twitter_embedding_model: huggingface model for Twitter embeddings
            twitter_sentiment_model: huggingface model for Twitter sentiment
            reddit_sentiment_model: huggingface model for Reddit sentiment
            twitter_sector_model: huggingface model for Twitter sector
        '''
        super(PreDUNES, self).__init__()
        self.twitter_embedding_model = twitter_embedding_model
        self.twitter_sentiment_tokenizer = twitter_sentiment_tokenizer
        self.twitter_sentiment_model = twitter_sentiment_model
        self.reddit_sentiment_tokenizer = reddit_sentiment_tokenizer
        self.reddit_sentiment_model = reddit_sentiment_model
        self.twitter_sector_tokenizer = twitter_sector_tokenizer
        self.twitter_sector_model = twitter_sector_model

    def forward(self, prev_tweet, prev_reddit):
        '''
        Forward pass for the DUNES model.
        Args:
            prev_tweet: previous tweet
            curr_tweet: current tweet
            prev_reddit: previous Reddit post
            curr_reddit: current Reddit post
        Returns:
            sentiment: sentiment of the current tweet
            sector: sector of the current tweet
        '''
        # Get the embeddings
        prev_tweet_embedding = self.twitter_embedding_model.encode([prev_tweet])
        

        # Get the sentiment
        prev_tweet_tokens = self.twitter_sentiment_tokenizer(prev_tweet, return_tensors='pt')
        prev_tweet_sentiment = self.twitter_sentiment_model(**prev_tweet_tokens)[0][0].detach().numpy()


        prev_reddit_tokens = self.reddit_sentiment_tokenizer(prev_reddit, return_tensors='pt')
        prev_reddit_sentiment = self.reddit_sentiment_model(**prev_reddit_tokens)[0][0].detach().numpy()

        # Get the sector
        prev_sector_tokens = self.twitter_sector_tokenizer(prev_tweet, return_tensors='pt')
        prev_tweet_sector = self.twitter_sector_model(**prev_sector_tokens)[0][0].detach().numpy()



        return prev_tweet_embedding, prev_tweet_sentiment, prev_reddit_sentiment, prev_tweet_sector


def create_preprocessing_model(
        twitter_embedding: str,
        twitter_sentiment: str,
        reddit_sentiment: str,
        twitter_sector: str
):
    '''
    Initialize a DUNES model from a set of embeddings and sentiment models.
    Feeds the output of these models to a transformer for classification.
    Args:
        twitter_embedding: path to the Twitter embedding huggingface model
        twitter_sentiment: path to the Twitter sentiment huggingface model
        reddit_sentiment: path to the Reddit sentiment huggingface model
        twitter_sector: path to the Twitter sector huggingface model
    '''

    # Load the models
    twitter_embedding_model = SentenceTransformer(twitter_embedding)
    twitter_sentiment_tokenizer = AutoTokenizer.from_pretrained(twitter_sentiment)
    twitter_sentiment_model = AutoModelForSequenceClassification.from_pretrained(twitter_sentiment)
    reddit_sentiment_tokenizer = AutoTokenizer.from_pretrained(reddit_sentiment)
    reddit_sentiment_model = AutoModelForSequenceClassification.from_pretrained(reddit_sentiment)
    twitter_sector_tokenizer = AutoTokenizer.from_pretrained(twitter_sector)
    twitter_sector_model = AutoModelForSequenceClassification.from_pretrained(twitter_sector)

    # Freeze the models
    twitter_sentiment_model.eval()
    twitter_sentiment_model.requires_grad_(False)
    reddit_sentiment_model.eval()
    reddit_sentiment_model.requires_grad_(False)
    twitter_sector_model.eval()
    twitter_sector_model.requires_grad_(False)

    # Create the DUNES model
    model = PreDUNES(
        twitter_embedding_model,
        twitter_sentiment_tokenizer,
        twitter_sentiment_model,
        reddit_sentiment_tokenizer,
        reddit_sentiment_model,
        twitter_sector_tokenizer,
        twitter_sector_model
        )

    return model

def printClassMappings(model, predictions):
    class_mapping = AutoModel.from_pretrained(model).config.id2label
    predictions = softmax(predictions)
    ranking = np.argsort(predictions)
    ranking = ranking[::-1]
    for i in range(predictions.shape[0]):
        l = class_mapping[ranking[i]]
        s = predictions[ranking[i]]
        print(f"{i+1}) {l} {np.round(float(s), 4)}")

def test():
    preprocessing_model = create_preprocessing_model(
        "mixedbread-ai/mxbai-embed-large-v1",
        "cardiffnlp/twitter-roberta-base-sentiment-latest",
        "bhadresh-savani/distilbert-base-uncased-emotion",
        "cardiffnlp/tweet-topic-latest-multi"
    )

    prev_tweet_embedding, prev_tweet_sentiment, prev_reddit_sentiment, prev_tweet_sector = preprocessing_model(
        "@WholeMarsBlog Headline is misleading. Starlink can obviously offer far more robust positioning than GPS, as it will have ~1000X more satellites over time. Not all will have line of sight to users, but still &gt;10X GPS &amp; far stronger signal. Just not today’s problem.",
        "We know who controls the media. The same corporations who have wreaked havoc on the globe for decades, if not centuries, the big banks who financed them, and the governments who turned a blind eye to the destruction. The same entities who have brought us to the precipice of destruction - quite possibly condemning us, and our progeny to an unlivable climate They have tried to stop you at every turn, and yet you persist for the good of humanity. We love you, Elon! Keep up the good work! As you have said, we must never let the light of human consciousness fade - never!"
    )

    print("prev_tweet_embedding:", prev_tweet_embedding)
    print("prev_tweet_sentiment:", softmax(prev_tweet_sentiment))
    

    print("prev_reddit_sentiment:")
    printClassMappings("bhadresh-savani/distilbert-base-uncased-emotion", prev_reddit_sentiment)
    print("prev_tweet_sector:")
    printClassMappings("cardiffnlp/tweet-topic-latest-multi", prev_tweet_sector)



In [16]:
test()

Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


prev_tweet_embedding: [[ 0.34489495  0.13846159 -0.32645586 ... -0.717675    0.28787875
  -0.14000623]]
prev_tweet_sentiment: [0.14040029 0.57469934 0.2849003 ]
prev_reddit_sentiment:
1) anger 0.9387
2) sadness 0.0523
3) joy 0.0054
4) love 0.0019
5) fear 0.0014
6) surprise 0.0004
prev_tweet_sector:


Some weights of RobertaModel were not initialized from the model checkpoint at cardiffnlp/tweet-topic-latest-multi and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


1) science_&_technology 0.9853
2) news_&_social_concern 0.0066
3) business_&_entrepreneurs 0.0027
4) other_hobbies 0.0008
5) diaries_&_daily_life 0.0007
6) learning_&_educational 0.0006
7) film_tv_&_video 0.0005
8) fitness_&_health 0.0004
9) celebrity_&_pop_culture 0.0004
10) gaming 0.0003
11) travel_&_adventure 0.0003
12) sports 0.0002
13) relationships 0.0002
14) youth_&_student_life 0.0002
15) music 0.0002
16) food_&_dining 0.0002
17) arts_&_culture 0.0002
18) family 0.0002
19) fashion_&_style 0.0001
