In [147]:
import asyncio
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, DistilBertConfig
from transformers import DistilBertModel

class TextInput(BaseModel):
    """JSON body accepted by ``POST /predict/``: ``{"text": "<string to classify>"}``."""

    text: str


# ASGI application object; the routes in this cell attach to it via @app.get / @app.post.
app = FastAPI()
    
    
class DistilBertForSequenceClassification(nn.Module):
    """DistilBERT encoder followed by a two-layer classification head.

    Note: unlike the ``transformers`` class of the same name, ``forward`` returns
    softmax *probabilities* (not logits) and never computes a loss.

    Args:
        config: Provides ``num_labels`` (width of the output layer) and
            ``seq_classif_dropout`` (dropout applied before the output layer),
            e.g. a ``DistilBertConfig``.
        pretrained_model_name_or_path: Checkpoint handed to
            ``DistilBertModel.from_pretrained`` for the encoder weights.
            Defaults to the previously hard-coded ``'distilbert-base-uncased'``.
    """

    def __init__(self, config, pretrained_model_name_or_path='distilbert-base-uncased'):
        super().__init__()
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel.from_pretrained(pretrained_model_name_or_path)
        # Size the head from the encoder that was actually loaded, not from `config`:
        # a config whose hidden size disagrees with the checkpoint would otherwise
        # build a head that only fails later, with a shape mismatch in forward().
        hidden_size = self.distilbert.config.hidden_size
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids=None, attention_mask=None, head_mask=None, labels=None):
        """Score a batch and return class probabilities of shape ``(batch, num_labels)``.

        Args:
            input_ids, attention_mask, head_mask: Passed straight to the encoder.
            labels: Accepted for call-site compatibility but ignored; no loss is computed.

        Returns:
            Softmax over the last dimension of the classifier output.
        """
        encoder_out = self.distilbert(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      head_mask=head_mask)
        hidden_state = encoder_out[0]  # per-token hidden states, presumably (batch, seq_len, hidden)
        # Pool by taking the first token position (the [CLS] token for the
        # DistilBERT tokenizer used in this notebook).
        pooled_output = hidden_state[:, 0]
        pooled_output = F.relu(self.pre_classifier(pooled_output))
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return F.softmax(logits, dim=-1)
    

# Fine-tuned weights for the served model: a plain state_dict, as written by the
# `torch.save(model.state_dict(), ...)` export cell further down this notebook.
MODEL_STATE_PATH = "distilbert_model_state.pth"

config = DistilBertConfig(num_labels=5)  # 5 classes, one per key of the genre label map
model = DistilBertForSequenceClassification(config=config)
# A state_dict contains only tensors, so the restricted weights_only unpickler is
# enough and avoids executing arbitrary pickled code from the file.
# map_location="cpu" lets a checkpoint saved from a GPU session load on a CPU-only server.
model.load_state_dict(torch.load(MODEL_STATE_PATH, map_location="cpu", weights_only=True))
model.eval()  # inference mode: disables dropout in the classification head


tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# `TextInput`, the request-body model used by `predict`, is defined once near the
# top of this cell; it is intentionally not redefined here.

@app.get("/")
async def read_root():
    """Landing route for ``GET /`` that always answers with a static greeting."""
    greeting = "Welcome to the FastAPI application!"
    return {"message": greeting}

@app.post("/predict/")
async def predict(input: TextInput):
    """Classify ``input.text`` and return the index of the most probable class.

    The text is tokenized (truncated to 512 tokens) and scored by the global
    ``model``, whose ``forward`` returns softmax probabilities rather than logits.

    Returns:
        ``{"prediction": <int>}`` -- index of the highest-probability class.
    """
    # Tokenize on the event-loop thread: it is cheap, and a shared HF fast
    # tokenizer can raise "Already borrowed" when called from several threads.
    encoded = tokenizer(input.text, return_tensors="pt", truncation=True, max_length=512)

    def _forward():
        # Grad mode is thread-local, so no_grad must be entered in the worker thread.
        with torch.no_grad():
            return model(**encoded)

    # The transformer forward pass is CPU-bound; running it inline in this
    # `async def` would block the event loop (and every other request) meanwhile.
    loop = asyncio.get_running_loop()
    probs = await loop.run_in_executor(None, _forward)

    # argmax over the class axis, then take the single example in the batch.
    predicted_class = probs.argmax(dim=-1)[0].item()
    return {"prediction": predicted_class}



In [148]:
# Genre label -> class index for the 5-way head (num_labels=5). Presumably the
# order used when the model was trained -- TODO confirm against the training code.
GENRE_TO_ID = {"rap": 0, "rock": 1, "rb": 2, "pop": 3, "country": 4}
# Inverse map: turns the integer returned by /predict/ back into a genre name.
ID_TO_GENRE = {idx: genre for genre, idx in GENRE_TO_ID.items()}
GENRE_TO_ID

{'rap': 0, 'rock': 1, 'rb': 2, 'pop': 3, 'country': 4}

In [149]:
import requests

# Client-side target: the `/predict/` route of the app, served locally on port 80.
API_BASE_URL = "http://127.0.0.1:80"
url = f"{API_BASE_URL}/predict/"



In [150]:
best_model = "3_0.72_full.pth"
if not os.path.isfile(best_model):
    # Fail before building a model (which downloads pretrained encoder weights)
    # and before touching the global `model` that the FastAPI handler serves.
    raise FileNotFoundError(f"checkpoint not found: {best_model!r}")
config = DistilBertConfig(num_labels=5)
# Build into a separate name: if `load_model` raises, the global `model` must not
# be left rebound to a freshly initialised (untrained) classifier head, which the
# export cell below would otherwise save over the serving checkpoint.
candidate_model = DistilBertForSequenceClassification(config=config)
model = load_model(best_model, candidate_model)

FileNotFoundError: [Errno 2] No such file or directory: '3_0.72_full.pth'

In [None]:
# Export only the fine-tuned weights (a state_dict) to the file the serving cell
# at the top of this notebook loads. Run this ONLY after the checkpoint-loading
# cell above succeeded: if that cell failed after re-creating `model`, `model`
# holds an untrained classifier head and this would overwrite the good weights.
MODEL_STATE_PATH = "distilbert_model_state.pth"
_tmp_state_path = MODEL_STATE_PATH + ".tmp"
torch.save(model.state_dict(), _tmp_state_path)
# os.replace swaps the finished file in with a single rename (atomic on POSIX),
# so an interrupted save cannot leave a truncated checkpoint at the served path.
os.replace(_tmp_state_path, MODEL_STATE_PATH)


In [None]:
# Sample song lyrics used as the `text` payload for the /predict/ request in the next cell.
lyrics = """I was told the true definition of a man was to never cry
Work 'til you tired (yeah), got to provide (yeah)
Always be the rock for my fam, protect them by all means
And give you the things that you need, baby
Our relationship is suffering
Tryna give you what I never had
You say I don't know to love you, baby
Well, I say show me the way
I keep my feelings deep inside I
Shadow them with my pride eye
I'm trying desperately, baby, just work with me
Teach me how to love
Show me the way to surrender my heart, girl, I'm so lost (yeah)
Teach me how to love (yeah)
How I can get my emotions involved (yeah), teach me
Show me how to love
Show me the way to surrender my heart, girl, I'm so lost (lost)
Teach me how to love
How I can get my emotions involved (yeah)
Teach me (uh), how to love
I was always taught to be strong
Never let them think you care at all
Let no one get close to me before you and me
I done shared things witchu, girl, about my past that I'd never tell
To anyone else, no, just keep it to myself, yeah
Girl, I know I lack affection and expressin' my feelings
It took me a minute to come and admit this but
See I'm really tryna change now
Wanna love you better, show me how
I'm tryin' desperately, baby, please work with me
Teach me how to love
Show me the way to surrender my heart (my heart), girl, I'm so lost (yeah)
Teach me how to love
How I can get my emotions involved (yeah), teach me (teach me)
Show me how to love
Show me the way to surrender my heart, girl, I'm so lost (oh)
Teach me how to love
How I can get my emotions involved (yeah)
Teach me (teach me), how to love
Ain't nobody ever took the time to try to teach me
What love was but you
And I ain't never trust no one enough
To let 'em tell me what to do
Teach me how to really show it
And show me how to really love you, baby
Teach me, please, just show me, yeah
'Cause I'm willing to let go of my fears
Girl, I'm serious about all that I've said
Girl, I wanna love you with all my heart (wanna love you with all my heart)
Baby, show me where to start
Teach me how to love
Show me the way to surrender my heart, girl, I'm so lost (ooh)
Teach me how to love (love, love, love)
How I can get my emotions involved, teach me (involved, yeah)
Show me how to love
Show me the way to surrender my heart, girl, I'm so lost (so lost)
Teach me how to love (yeah)
How I can get my emotions involved (teach me)
Teach me how, how to love
Teach me how to love you, baby (uh)
Girl, just teach me how to love you better
You know I wanna love you better, girl
Oh, yeah, yeah, yeah

"""

In [151]:
# NOTE(review): the recorded output below is a per-genre percentage dict, but the
# `predict` handler defined in this notebook returns {"prediction": <int>}. The
# server that produced it was running different code; re-run to refresh the output.
REQUEST_TIMEOUT_S = 30  # seconds; without a timeout, requests.post can wait forever

data = {
    "text": lyrics
}

# Send a POST request to the API
response = requests.post(url, json=data, timeout=REQUEST_TIMEOUT_S)
# Surface HTTP errors (e.g. 422 validation, 500) instead of printing an error body
response.raise_for_status()

# Print the response from the API
print(response.json())

{'rap': '14.14%', 'rock': '0.37%', 'rb': '79.55%', 'pop': '5.91%', 'country': '0.02%'}
