In [1]:
import os

from pathlib import Path
from datetime import datetime, timedelta
from dotenv import load_dotenv
from itertools import product

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from bson import ObjectId
from IPython.display import display
from pymongo import MongoClient, errors
from pymongo.database import Database
from tqdm import tqdm

load_dotenv("../../.env")
pd.options.plotting.backend = "plotly"

# Data collection

In [2]:
DATABASE = 'insightfinder-dev'

TOPICS_COLLECTION = 'topics'
TOPIC_EMBEDDINGS_COLLECTION = 'topic_embeddings'

MONGO_HOST = os.getenv("MONGO_HOST")

In [3]:
def mongo_transaction(func):
    """
    A decorator to execute a MongoDB operation with error handling.
    """
    def wrapper(*args, **kwargs):
        try:
            with MongoClient(MONGO_HOST) as mongo_client:
                db = mongo_client[DATABASE]
                return func(db, *args, **kwargs)
        except errors.PyMongoError as e:
            print(f"MongoDB error: {e}")
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
    return wrapper


@mongo_transaction
def retrieve_topics_assignments(db: Database, start_date: str, end_date: str, skip_noise_topic: bool = True):
    topics_collection = db[TOPICS_COLLECTION]
    query = {"topic_start_date": {"$eq": start_date}, "topic_end_date": {"$eq": end_date}}
    if skip_noise_topic:
        query["topic_index"] = {"$ne": -1}
    return list(topics_collection.find(filter=query))


@mongo_transaction
def get_topic_embeddings(db: Database, topic_ids: list[ObjectId]):
    topic_embeddings_collection = db[TOPIC_EMBEDDINGS_COLLECTION]
    query = {"_id": {"$in": topic_ids}}
    return list(topic_embeddings_collection.find(filter=query))

In [4]:
# get week start and end dates

start_date = "2024-05-01"
end_date = "2024-08-01"

mondays = pd.date_range(start_date, end_date, freq="W-MON")
date_ranges = []

for monday, next_monday in zip(mondays, mondays[1:]):
    sunday = next_monday - timedelta(days=1)
    date_ranges.append((monday.strftime("%Y-%m-%d"), sunday.strftime("%Y-%m-%d")))

len(date_ranges)

12

In [5]:
# retrieve topics data per week

topic_assignments = {}
topic_embeddings = {}

for sd, ed in tqdm(date_ranges):
    ta = retrieve_topics_assignments(sd, ed)
    ta = pd.DataFrame(ta)
    topic_assignments[(sd, ed)] = ta

    tids = ta["topic_id"].tolist()
    for x in get_topic_embeddings(tids):
        topic_embeddings[x["_id"]] = x["embedding"]

print(len(topic_assignments))
print(len(topic_embeddings))

100%|███████████████████████████████████████████| 12/12 [00:14<00:00,  1.20s/it]

12
1238





# Dataset creation

- For each week, extract trending topics by using top 10% of topics as trending
- Identify similar topics from previous week based on similarity scores
- Label those topics as trending for next week

In [6]:
def pairwise_cosine_similarity(array1, array2):
    """
    Calculate pairwise cosine similarity between rows of two 2D numpy arrays.

    :param array1: A 2D numpy array where each row is a vector.
    :param array2: A 2D numpy array where each row is a vector.
    :return: A 2D numpy array containing the cosine similarity scores.
    """
    # Normalize the vectors in each array
    norms1 = np.linalg.norm(array1, axis=1, keepdims=True)
    norms2 = np.linalg.norm(array2, axis=1, keepdims=True)

    # Handle cases where norm is zero to avoid division by zero
    norms1[norms1 == 0] = 1
    norms2[norms2 == 0] = 1

    normalized_array1 = array1 / norms1
    normalized_array2 = array2 / norms2

    # Compute the pairwise cosine similarity
    similarity_matrix = np.dot(normalized_array1, normalized_array2.T)

    return similarity_matrix


def match_topics(
    sim_matrix,
    src_topics: list,
    dst_topics: list,
    topk: int = 5,
    min_similarity: float = 0.9,
    return_scores: bool = False,
):
    """
    Matches src_topics to dst_topics.
    min_similarity - minimum cosine similarity for a topic pair to be considered similar
    topk - maximum number of dst topics to be matched to each src topic
    """
    result = {}
    for idx, topic_id in enumerate(src_topics):
        similarities = sorted(zip(dst_topics, sim_matrix[idx]), key=lambda x: x[-1], reverse=True)
        similarities = list(filter(lambda x: x[-1] >= min_similarity, similarities))
        similarities = similarities[:topk]
        if return_scores:
            result[topic_id] = similarities[:]
        else:
            result[topic_id] = [x[0] for x in similarities]
    return result


def get_topic_counts(topic_assignments):
    """
    Calculates the total number of articles/documents of each topic inside `topic_assignments`
    """
    result = topic_assignments.groupby(by="topic_id")["document_id"].nunique()
    result = result.reset_index()
    result = result.rename(columns={"document_id": "count"})
    result = result.set_index("topic_id")["count"].to_dict()
    return result


def get_trending_topics(topic_assignments, top_pct: float = 0.05):
    topic_counts = get_topic_counts(topic_assignments)
    topic_counts = pd.Series(topic_counts)
    threshold = topic_counts.quantile(1 - top_pct)
    trending_topics = topic_counts[topic_counts >= threshold]
    return trending_topics.index.tolist()


def find_matches_to_trending_topics(prev_week_date_range, next_week_date_range, topk: int = 5, min_similarity: float = 0.9):
    """
    Matches trending topics from next_topics_date_range to topics from prev_week_date_range by:
    1. Calculating the similarity matrix between start_topics and next_topics
    2. Matching start_topics to next_topics using similarity scores
    3. Applying the min_similarity threshold
    4. Yielding at most topk matches per start topic
    """

    dst_topics = topic_assignments[prev_week_date_range]["topic_id"].unique().tolist()
    src_topics = get_trending_topics(topic_assignments[next_week_date_range], top_pct=0.1)

    dst_topics_embeddings = np.array([topic_embeddings[topic_id] for topic_id in dst_topics])
    src_topics_embeddings = np.array([topic_embeddings[topic_id] for topic_id in src_topics])

    similarity_matrix = pairwise_cosine_similarity(src_topics_embeddings, dst_topics_embeddings)
    topics_matching = match_topics(similarity_matrix, src_topics, dst_topics, topk, min_similarity)

    topics_matched_trending = []
    for matched_topics in topics_matching.values():
        topics_matched_trending.extend(matched_topics)
    return list(set(topics_matched_trending))

In [7]:
# get all matches of previous week's topics to next week's trending topics

topics_matched_trending = []
date_ranges = list(topic_assignments)

for prev_week_date_range, next_week_date_range in zip(date_ranges, date_ranges[1:]):
    topics_matched_trending.extend(find_matches_to_trending_topics(
        prev_week_date_range,
        next_week_date_range,
        topk=10,
        min_similarity=0.85,
    ))

len(topics_matched_trending)

329

In [8]:
# cross join topics to all dates in their date range (required to fill in NAs in topics daily counts)

dates_per_week = {}
for date_range in date_ranges:
    dates_per_week[date_range] = pd.date_range(date_range[0], date_range[-1]).strftime("%Y-%m-%d").tolist()

topics_per_week = {}
for date_range, ta in topic_assignments.items():
    topics_per_week[date_range] = ta["topic_id"].unique().tolist()

topics_dates_index = []
for date_range in date_ranges:
    dates_in_week = dates_per_week[date_range]
    topics_in_week = topics_per_week[date_range]
    topics_dates_index.extend(list(product(topics_in_week, dates_in_week)))

topics_dates_index = pd.DataFrame(topics_dates_index, columns=["topic_id", "date"])

print(topics_dates_index.shape)
display(topics_dates_index.head())

(8666, 2)


Unnamed: 0,topic_id,date
0,66abfb3f5d3be4373f7683f0,2024-05-06
1,66abfb3f5d3be4373f7683f0,2024-05-07
2,66abfb3f5d3be4373f7683f0,2024-05-08
3,66abfb3f5d3be4373f7683f0,2024-05-09
4,66abfb3f5d3be4373f7683f0,2024-05-10


In [9]:
# get daily counts per topic
topic_assignments_df = pd.concat(list(topic_assignments.values()))
topics_daily_counts = topic_assignments_df.groupby(by=["topic_id", "document_date"])["document_id"].nunique()
topics_daily_counts = topics_daily_counts.reset_index()
topics_daily_counts = topics_daily_counts.rename(columns={"document_id": "count", "document_date": "date"})

# join with topics_daily_counts to fill in NAs
topics_daily_counts = pd.merge(
    topics_dates_index,
    topics_daily_counts,
    how="left",
    on=["topic_id", "date"]
)
topics_daily_counts["count"] = topics_daily_counts["count"].fillna(value=0).astype(int)

# sort topics_daily_counts for simplified transformations
topics_daily_counts = topics_daily_counts.sort_values(by=["date", "topic_id", "count"])
topics_daily_counts = topics_daily_counts.reset_index(drop=True)

topics_daily_counts.shape

(8666, 3)

In [10]:
# transform counts into a list representation per topic
dataset = topics_daily_counts.groupby(by="topic_id").agg({"date": list, "count": list})
dataset = dataset.reset_index()

# assign labels
dataset["matches_trending"] = 0
dataset.loc[dataset["topic_id"].isin(topics_matched_trending), "matches_trending"] = 1

print(dataset.shape)
display(dataset.head())

(1238, 4)


Unnamed: 0,topic_id,date,count,matches_trending
0,66abfb3f5d3be4373f7683e7,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[22, 38, 29, 16, 28, 12, 18]",1
1,66abfb3f5d3be4373f7683e8,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[6, 5, 13, 9, 21, 8, 8]",1
2,66abfb3f5d3be4373f7683e9,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[13, 19, 12, 9, 7, 3, 5]",1
3,66abfb3f5d3be4373f7683ea,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[12, 8, 8, 1, 8, 7, 6]",1
4,66abfb3f5d3be4373f7683eb,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[9, 5, 10, 5, 11, 5, 5]",1


In [11]:
dataset["matches_trending"].value_counts()

matches_trending
0    909
1    329
Name: count, dtype: int64

In [12]:
# add topic embeddings
dataset["embedding"] = dataset["topic_id"].apply(topic_embeddings.get)

print(dataset["embedding"].isna().sum())
print(dataset.shape)
display(dataset.head())

0
(1238, 5)


Unnamed: 0,topic_id,date,count,matches_trending,embedding
0,66abfb3f5d3be4373f7683e7,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[22, 38, 29, 16, 28, 12, 18]",1,"[0.014844002455834355, 0.13961728389523473, -0..."
1,66abfb3f5d3be4373f7683e8,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[6, 5, 13, 9, 21, 8, 8]",1,"[0.06552150122824188, 0.22931335911588982, 0.0..."
2,66abfb3f5d3be4373f7683e9,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[13, 19, 12, 9, 7, 3, 5]",1,"[0.10776098631322384, 0.2900739529835326, -0.0..."
3,66abfb3f5d3be4373f7683ea,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[12, 8, 8, 1, 8, 7, 6]",1,"[0.038234669270048684, 0.13458719113430775, -0..."
4,66abfb3f5d3be4373f7683eb,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[9, 5, 10, 5, 11, 5, 5]",1,"[0.05847633781377226, 0.2113554148375988, 0.01..."


In [13]:
def save_dataset(dataset, path):
    dataset_ = dataset.copy(deep=True)
    dataset_["topic_id"] = dataset_["topic_id"].apply(str)
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    dataset_.to_json(path)

In [14]:
# checkpoint the dataset

today_str = datetime.now().strftime("%y%m%d")
dataset_path = f"../data/datasets/{today_str}_v2.json"
save_dataset(dataset, dataset_path)

# Modelling

In [15]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings

from bson import ObjectId
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

# warnings.filterwarnings("ignore")
pd.options.plotting.backend = "plotly"

In [16]:
def load_dataset(path):
    dataset_ = pd.read_json(path)
    # dataset_["topic_id"] = dataset_["topic_id"].apply(ObjectId)
    return dataset_

In [18]:
dataset_path = f"../data/datasets/240805_v2.json"
dataset_df = load_dataset(dataset_path)

print(dataset_df.shape)
display(dataset_df.head())

(1238, 5)


Unnamed: 0,topic_id,date,count,matches_trending,embedding
0,66abfb3f5d3be4373f7683e7,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[22, 38, 29, 16, 28, 12, 18]",1,"[0.0148440025, 0.1396172839, -0.0371434999, -0..."
1,66abfb3f5d3be4373f7683e8,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[6, 5, 13, 9, 21, 8, 8]",1,"[0.0655215012, 0.2293133591, 0.0490570989, -0...."
2,66abfb3f5d3be4373f7683e9,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[13, 19, 12, 9, 7, 3, 5]",1,"[0.1077609863, 0.290073953, -0.0872822849, -0...."
3,66abfb3f5d3be4373f7683ea,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[12, 8, 8, 1, 8, 7, 6]",1,"[0.0382346693, 0.1345871911, -0.118266835, -0...."
4,66abfb3f5d3be4373f7683eb,"[2024-05-06, 2024-05-07, 2024-05-08, 2024-05-0...","[9, 5, 10, 5, 11, 5, 5]",1,"[0.0584763378, 0.2113554148, 0.0159812742, -0...."


In [19]:
def train_test_split(test_pct: float = 0.2):
    num_train_samples = int(dataset_df.shape[0] * (1 - test_pct))
    num_test_samples = dataset_df.shape[0] - num_train_samples
    shuffled = dataset_df.sample(frac=1.0)
    train_df = shuffled.head(num_train_samples).reset_index(drop=True)
    test_df = shuffled.tail(num_test_samples).reset_index(drop=True)
    return train_df, test_df
    

train_df, test_df = train_test_split(test_pct=0.2)
print(f"{train_df.shape=}")
print(f"{test_df.shape=}")
assert train_df.shape[0] + test_df.shape[0] == dataset_df.shape[0]
assert len(set(train_df["topic_id"]).intersection(set(test_df["topic_id"]))) == 0

train_df.shape=(990, 5)
test_df.shape=(248, 5)


In [20]:
class Dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # we need to convert ObjectId into strings
        topic_id = str(self.data.iloc[idx]['topic_id'])
        counts = torch.tensor(self.data.iloc[idx]['count'], dtype=torch.float32)
        embedding = torch.tensor(self.data.iloc[idx]['embedding'], dtype=torch.float32)
        label = torch.tensor(self.data.iloc[idx]['matches_trending'], dtype=torch.float32)
        return {"topic_id": topic_id, "counts": counts, "embedding": embedding, "label": label}

In [21]:
train_dataset = Dataset(train_df)
test_dataset = Dataset(test_df)

In [22]:
class Model(nn.Module):
    def __init__(self, embedding_dim, counts_dim, counts_hidden_size):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=counts_hidden_size, batch_first=True)
        self.fc_time = nn.Linear(counts_hidden_size, embedding_dim)
        self.fc_topic = nn.Linear(embedding_dim, embedding_dim)
        self.fc_combined = nn.Linear(embedding_dim * 2, 1)

    def forward(self, counts, topic_emb):
        # Reshape counts for LSTM
        counts = counts.unsqueeze(-1)  # Shape: [batch_size, counts_dim, 1]
        
        # Process counts with LSTM
        lstm_out, _ = self.lstm(counts)
        
        # Use the output from the last time step
        time_emb = lstm_out[:, -1, :]
        
        # Transform time_emb to match embedding dimensions
        time_emb = F.relu(self.fc_time(time_emb))

        # Transform topic embeddings
        topic_emb_processed = F.relu(self.fc_topic(topic_emb))
        
        # Combine the embeddings
        combined_emb = torch.cat((time_emb, topic_emb_processed), dim=1)
        
        # Classification
        output = self.fc_combined(combined_emb)
        return output

In [23]:
counts_hidden_size = 32
embedding_dim = train_dataset[0]["embedding"].shape[-1]
counts_dim = len(train_dataset[0]["counts"])

model = Model(embedding_dim=embedding_dim, counts_dim=counts_dim, counts_hidden_size=counts_hidden_size)
summary(model)

Layer (type:depth-idx)                   Param #
Model                                    --
├─LSTM: 1-1                              4,480
├─Linear: 1-2                            12,672
├─Linear: 1-3                            147,840
├─Linear: 1-4                            769
Total params: 165,761
Trainable params: 165,761
Non-trainable params: 0

In [24]:
def train_epoch(model, dataloader, criterion, optimizer):
    model.train()
    running_loss = 0.0

    for batch in dataloader:
        optimizer.zero_grad()
        logits = model(batch["counts"], batch["embedding"])
        labels = batch["label"]
        loss = criterion(logits.squeeze(), labels)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_loss = running_loss / len(dataloader)
    return avg_loss


def evaluate_model(model, dataloader, criterion, prefix: str = ""):
    model.eval()
    all_preds = []
    all_labels = []
    running_loss = 0.0
    
    with torch.no_grad():
        for batch in dataloader:
            logits = model(batch["counts"], batch["embedding"])
            preds = torch.sigmoid(logits).squeeze() > 0.5
            labels = batch["label"]

            # Compute loss
            loss = criterion(logits.squeeze(), labels)
            running_loss += loss.item()
        
            # Collect all predictions and labels
            all_preds.extend(preds.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

    predictions = np.array(all_preds)
    actuals = np.array(all_labels)

    response = {
        "loss": running_loss / len(dataloader),
        "precision": precision_score(predictions, actuals),
        "recall": recall_score(predictions, actuals),
        "f1": f1_score(predictions, actuals),
        "accuracy": accuracy_score(predictions, actuals),
    }
    if prefix:
        response = {f"{prefix}{key}": value for key, value in response.items()}
    return response

In [None]:
train_batch_size = 16
test_batch_size = 16
num_epochs = 100
learning_rate = 1e-4
weight_decay = 1e-3

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Training loop
train_log = []
for epoch in tqdm(range(num_epochs), position=0, leave=True):
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    train_loss = train_epoch(model, train_dataloader, criterion, optimizer)

    train_dataloader = DataLoader(train_dataset, batch_size=test_batch_size)
    train_metrics = evaluate_model(model, train_dataloader, criterion, "train_")

    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size)
    test_metrics = evaluate_model(model, test_dataloader, criterion, "test_")
    train_log.append({"epoch": epoch, **train_metrics, **test_metrics})

    if epoch % 5 == 0:
        display(pd.Series(train_log[-1]).to_frame().T)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,0.0,0.606721,0.0,0.0,0.0,0.737374,0.615993,0.0,0.0,0.0,0.721774


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  5%|██▏                                        | 5/100 [00:03<01:01,  1.54it/s]

Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,5.0,0.489148,0.226923,0.7375,0.347059,0.775758,0.517181,0.173913,0.705882,0.27907,0.75


 10%|████▏                                     | 10/100 [00:05<00:56,  1.60it/s]

Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,10.0,0.455564,0.338462,0.739496,0.46438,0.794949,0.493666,0.318841,0.733333,0.444444,0.778226


 15%|██████▎                                   | 15/100 [00:09<01:04,  1.32it/s]

Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,15.0,0.43935,0.419231,0.736486,0.534314,0.808081,0.481383,0.391304,0.627907,0.482143,0.766129


 20%|████████▍                                 | 20/100 [00:12<00:46,  1.72it/s]

Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,20.0,0.425646,0.423077,0.738255,0.537897,0.809091,0.476255,0.376812,0.65,0.477064,0.770161


 25%|██████████▌                               | 25/100 [00:14<00:34,  2.20it/s]

Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,25.0,0.415621,0.453846,0.7375,0.561905,0.814141,0.470739,0.391304,0.658537,0.490909,0.774194


 30%|████████████▌                             | 30/100 [00:17<00:30,  2.31it/s]

Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,30.0,0.410867,0.542308,0.734375,0.623894,0.828283,0.464162,0.449275,0.645833,0.529915,0.778226


 35%|██████████████▋                           | 35/100 [00:19<00:26,  2.44it/s]

Unnamed: 0,epoch,train_loss,train_precision,train_recall,train_f1,train_accuracy,test_loss,test_precision,test_recall,test_f1,test_accuracy
0,35.0,0.399408,0.496154,0.767857,0.602804,0.828283,0.460984,0.434783,0.681818,0.530973,0.78629


 38%|███████████████▉                          | 38/100 [00:20<00:27,  2.27it/s]

In [None]:
train_log = pd.DataFrame(train_log).set_index("epoch")
train_log

In [None]:
metric = "loss"
train_log[[f"train_{metric}", f"test_{metric}"]].plot(title=metric.capitalize())

In [None]:
metric = "accuracy"
train_log[[f"train_{metric}", f"test_{metric}"]].plot(title=metric.capitalize())

In [None]:
metric = "recall"
train_log[[f"train_{metric}", f"test_{metric}"]].plot(title=metric.capitalize())

In [None]:
metric = "precision"
train_log[[f"train_{metric}", f"test_{metric}"]].plot(title=metric.capitalize())

In [None]:
metric = "f1"
train_log[[f"train_{metric}", f"test_{metric}"]].plot(title=metric.capitalize())