In [1]:
import pandas as pd
import torch 
from torch import nn
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, RobertaForTokenClassification
from datasets import Dataset
from tqdm import tqdm 
from collections import Counter
from transformers import AutoTokenizer, AutoModel, PreTrainedTokenizerFast, RobertaTokenizerFast
import numpy as np

In [2]:
#model class 
class Model(nn.Module):
    def __init__(self):
        #def __init__(self):
        super(Model,self).__init__()
        self.model = RobertaModel.from_pretrained('roberta-base')

        #since we have three classes 
        self.l1 = nn.Linear(768, 3)

        #normalizes probabilities to sum to 1
        self.sig = nn.Sigmoid()
        #self.ixList = ixList
        
    def mean_pooling(self, token_embeddings, attention_mask): 
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
    #outIndices tells us the indices of the tokens corresponding to our word of interest
    #for each instance in our batch
    def forward(self, input_ids, attention_mask, outIndices): 
        
        #encode sentence and get mean pooled sentence representation 
        output = self.model(input_ids, attention_mask=attention_mask)
        
        #now we just have outIndices come to us in the forward pass 
        #outIndices = [self.ixList[ix] for ix in index]
        embeddingMeans = []
        batchIter = 0
        for batchIter in range(input_ids.shape[0]): 
            
            #get the last layer of the model 
            hiddenStates = output[0]
            
            #get the embeddings corresponding to the entity we're interested in 
            tokStates = [hiddenStates[batchIter][tokIndex,:] for tokIndex in outIndices[batchIter]]
            
            #take the mean over all embeddings for an entity 
            embeddingMean = torch.stack(tokStates).mean(dim=0)
            
            #append this so we get the mean embedding for each 
            #training example in this batch 
            embeddingMeans.append(embeddingMean) 
            #embeddingMeans.append(hiddenStates[batchIter][outIndices[batchIter][0],:])
        
        #we stack because this is for an entire batch 
        embeddingMeans = torch.stack(embeddingMeans)
        """
        working code just used this!
        embeddingMeans = self.mean_pooling(output[0], attention_mask)
        """
        probs = self.sig(self.l1(embeddingMeans)).squeeze()
        
        return probs

In [3]:
torch.cuda.is_available()

True

In [4]:
#load best model from training
STATE_PATH  = "/shared/3/projects/benlitterer/podcastData/hostGuestModels/initialModel/bestF1Params"
deviceNum = 1
device = torch.device("cuda:" + str(deviceNum) if torch.cuda.is_available() else "cpu")
print(device)
model = Model().to(device)
model.load_state_dict(torch.load(STATE_PATH))

cuda:1


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [5]:
#load dataset
df = pd.read_json("/shared/3/projects/benlitterer/podcastData/processed/floydMonth/floydMonthDataClean.jsonl", orient="records", lines=True, nrows=10000)

In [6]:
toKeep = ["potentialOutPath", "transcript", "rssUrl", "epTitle", "epDescription", "transEnts", "transStarts", "transEnds", "transTypes"]
df = df[toKeep].explode(["transEnts", "transStarts", "transEnds", "transTypes"])
df = df[df["transTypes"] == "PERSON"]

In [7]:
#removing mentions after a certain point 
#first get the number of words before the entity, we only use < 350 to train, so go with that 
def getEntPos(inRow): 
    return len(inRow["transcript"][:inRow["transStarts"]].split())

df["entPos"] = df.apply(getEntPos, axis=1)
df = df[df["entPos"] < 350]

In [8]:
df["transEntLen"] = df["transEnts"].apply(lambda x: len(x.split()))
df = df[df["transEntLen"] == 2]

In [14]:
df.shape

(9854, 14)

In [15]:
df = df.sort_values(["potentialOutPath", "transEnts"]) 

In [16]:

BEFORE_BUFF = 50
AFTER_BUFF=50
#PUNCH IN HERE
def getSnippet(row): 
    #find where the entity starts quick 
   # row["snippetStart"] = trainDf.apply(lambda x: x["entSnippets"].lower().find(x["ent"].lower()), axis=1)
    
    snippet = row["transcript"]
    entStart = row["transStarts"]
    entEnd = row["transEnds"]

    
    beforeWords = snippet[0:entStart].split(" ")
    if len(beforeWords) >= BEFORE_BUFF: 
        buffStart = " ".join(beforeWords[-BEFORE_BUFF:]) 
    else: 
        buffStart = " ".join(beforeWords) 

    afterWords = snippet[entEnd:len(snippet)].split(" ")

    if len(afterWords) >= AFTER_BUFF: 
        buffEnd = " ".join(afterWords[:AFTER_BUFF]) 
    else: 
        buffEnd = " ".join(afterWords) 
    return [buffStart, snippet[entStart:entEnd], buffEnd]
            

df[["left", "ent", "right"]] = pd.DataFrame(df.apply(getSnippet, axis=1).tolist(), index=df.index)

KeyError: 'transcript'

In [12]:
#for the sake of memory 
df = df.drop(columns=["transcript"])

In [17]:
df["entSnippets"] = df["left"] + df["ent"] + df["right"] 

#df = df[["left", "right", "ent",'transStarts', 'transEnds', 'groundTruth', 'entSnippets']]
#df = df.dropna()

In [21]:
df["snippetStart"] = df.apply(lambda x: x["entSnippets"].lower().find(x["ent"].lower()), axis=1)
df["snippetEnd"] = df["snippetStart"] + df["transEnds"] - df["transStarts"]

def extractEnt(inRow): 
    return inRow["entSnippets"][inRow["snippetStart"]:inRow["snippetEnd"]]

df["extractedEnt"] = df.apply(extractEnt, axis=1)

In [22]:
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', max_length=512, padding="max_length", truncation=True)

In [23]:
tokenized = []
for snip in df["entSnippets"]: 
    tokenized.append(tokenizer(snip, padding = "max_length", truncation=True, return_offsets_mapping=True))

df = pd.concat([df.reset_index(), pd.DataFrame.from_records(tokenized)], axis=1) 

In [24]:
#find the token indices which correspond to our entity 
def getTokenIndices(start, end, offsets):
    """
    print(start) 
    print(end) 
    print(offsets[:20]) 
    """

    currIndices = []
    for j, offset in enumerate(offsets): 
        offsetL, offsetR = offset
        if offsetL >= start and offsetR <= end: 
            currIndices.append(j)

    return currIndices

In [25]:
df["posTokens"] = df.apply(lambda row: getTokenIndices(row["snippetStart"], row["snippetEnd"], row["offset_mapping"]), axis=1)

#drop extra information about location of tokens that aren't those of interest
df = df.drop(columns=["offset_mapping"])

labList = []
for i, row in df.iterrows(): 
    tokCount = sum(row["attention_mask"])
    paddingLen = len(row["attention_mask"]) - tokCount
    
    labels = ([0] * tokCount) + ([2] * paddingLen)
    
    for posIndex in row["posTokens"]: 
        labels[posIndex] = 1
    
    labList.append(labels) 

df["labels"] = labList

df["entsTokenized"] = df.apply(lambda row: [tokenizer.decode(row["input_ids"][i]) for i in row["posTokens"]], axis=1) 

In [26]:
extractionErrorCount = len(df[df["extractedEnt"].apply(lambda x: x.lower()) != df["ent"].apply(lambda x: x.lower())]) 
print(f"Number of entities where we have error from extraction of entity: {extractionErrorCount}")

Number of entities where we have error from extraction of entity: 0


In [27]:
df = df.drop(columns=["index"])

In [28]:
df = df.dropna(subset=["attention_mask", "input_ids", "posTokens"])

#add an index to be used for getting the position of tokens later 
df = df.reset_index(drop=True).reset_index()

In [29]:
#tokenRef = dict(zip(df["index"], df["posTokens"]))
tokenIxList = list(df["posTokens"])

In [30]:
BATCH_SIZE=4
dataset = Dataset.from_pandas(df)
dataset.set_format(type='torch', columns=["input_ids", "attention_mask", "index"])
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [31]:
#input ids, mask, indices 
predList = []
probList = []
#TODO: remove full transcript column here to save memory 
for batch in tqdm(loader):
    input_ids = batch["input_ids"].to(device) 
    attention_mask = batch["attention_mask"].to(device) 
    index = batch["index"]
    outIndices = [tokenIxList[i] for i in index]
    
    probs = model(input_ids, attention_mask, outIndices) #.to(torch.float32)

    #if the last batch has only 1 row, we need to add another dimension in 
    if len(probs.shape) == 1: 
        probs = probs.unsqueeze(0)
    
    probList += probs.to("cpu").tolist()
    preds = torch.max(probs, 1).indices.to(int).cpu().tolist()
    predList += preds

  0%|          | 0/2464 [00:00<?, ?it/s]

100%|██████████| 2464/2464 [04:34<00:00,  8.98it/s]


In [32]:
df["pred"] = predList
df["prob"] = probList

In [38]:
outPath = "/shared/3/projects/benlitterer/podcastData/hostIdentification/hostGuestPredictions/10000LongPredictions.json"
df[["potentialOutPath", "rssUrl", "ent", "pred", "prob"]].to_json(outPath, orient="records", lines=True)

In [34]:
def getMode(inList): 
    if len(inList) == 1: 
        return inList[0]
    
    data = Counter(inList)
    modeVal, modeCount = data.most_common(1)[0]

    #we default to neither if we have a split decision
    
    if modeCount == 1: 
        return 2
    else: 
        return modeVal 
    
    return modeVal

#here we take the index of the maximum probability prediction 
#after mean pooling over columns 
def getConfidenceAggregation(inList): 
    inList = np.array(inList)
    return np.argmax(np.mean(inList, axis=0))

#we take in a 2d array of shape n x 3
#get the prediction for the row with the highest probability 
def getMostConfident(inList): 

    maxVal = 0 
    maxValIx = 2
    for row in inList: 
        for colNum, item in enumerate(row): 

            #if we have a new highest value, update 
            #note that maxValIx is just our prediction of 0, 1, or 2
            if item > maxVal: 
                maxVal = item 
                maxValIx = colNum
    return maxValIx

aggDf = df[["potentialOutPath", "ent", "pred", "prob"]].groupby(["potentialOutPath", "ent"]).agg(list)
aggDf["modalPred"] = aggDf["pred"].apply(getMode)
aggDf["confPred"] = aggDf["prob"].apply(getMostConfident)
aggDf["meanAggPred"] = aggDf["prob"].apply(getConfidenceAggregation)


In [35]:
aggArr = aggDf[["modalPred", "confPred", "meanAggPred"]].T.values.tolist()

In [36]:
np.corrcoef(aggArr)

array([[1.        , 0.96811818, 0.9708293 ],
       [0.96811818, 1.        , 0.99684888],
       [0.9708293 , 0.99684888, 1.        ]])

In [246]:
outPath = "/shared/3/projects/benlitterer/podcastData/hostIdentification/hostGuestPredictions/1000predictions.json"
aggDf.to_json(outPath, orient="records", lines=True)

Unnamed: 0_level_0,Unnamed: 1_level_0,pred
potentialOutPath,ent,Unnamed: 2_level_1
/anchor.fm/0a/httpsanchor.fms59db584podcastplay13749364https3A2F2Fd3ctxlq1ktw2nl.cloudfront.net2Fstaging2F202005142F995a5a1ac4131000030f0993e5afc00f.m4aMERGED,Don Chanel,0
/anchor.fm/20/httpsanchor.fms126c0978podcastplay14603313https3A2F2Fd3ctxlq1ktw2nl.cloudfront.net2Fproduction2F20204312F7840600244100257f2a88169a22.m4aMERGED,Brody Myers,0
/anchor.fm/20/httpsanchor.fms126c0978podcastplay14603313https3A2F2Fd3ctxlq1ktw2nl.cloudfront.net2Fproduction2F20204312F7840600244100257f2a88169a22.m4aMERGED,Micah Wilcox,1
/anchor.fm/20/httpsanchor.fms2e2929cpodcastplay14311720https3A2F2Fd3ctxlq1ktw2nl.cloudfront.net2Fproduction2F20204262F76791003480002c47e0a8a50627.mp3MERGED,Shadow Shkowski,1
/anchor.fm/20/httpsanchor.fms67be020podcastplay14204913https3A2F2Fd3ctxlq1ktw2nl.cloudfront.net2Fproduction2F20204242F7612718532000181158a6abd0e8.mp3MERGED,Eppie Ludwig,1
...,...,...
/traffic.megaphone.fm/6A/httpstraffic.megaphone.fmAPO6964160656.mp3MERGED,Tony Gill,1
/traffic.megaphone.fm/73/httpstraffic.megaphone.fmAPO7738284721.mp3MERGED,Anita Johnston,1
/traffic.megaphone.fm/86/httpstraffic.megaphone.fmAPO8028686917.mp3MERGED,Brooke Castillo's,1
/www.podtrac.com/re/httpswww.podtrac.comptsredirect.mp3pdst.fmer.zencastr.comrtraffic.megaphone.fmHUMAN2742163750.mp3updated1666811397MERGED,Chris Gray,1
