In [None]:
import os
# Resolve `predir`, the project root that the data/ and figs/ paths used later
# in this notebook are built from.
if ('google' in str(get_ipython())):
    # Running on Google Colab: mount Drive and point at the thesis folder in it.
    from google.colab import drive
    drive.mount('ME', force_remount=True)
    predir='ME/MyDrive/Colab_Notebooks/thesis'
else:
    # Running locally: without this branch `predir` stays undefined and the
    # first file access fails with a NameError. Adjust if the project root is
    # not the notebook's working directory.
    predir='.'

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd
import numpy as np
import warnings
from tqdm import tqdm
import pickle

from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
from variables import groups,group_order,groupmap,group_titles,finished_usernames

In [None]:
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only
# load pickles that were produced by this project.
with open(predir+'/data/df_bypost_all.pkl', 'rb') as f:
    df_bypost = pickle.load(f)

# Keep posts that have a timestamp, at least one like, and a username listed in
# `finished_usernames`. The username mask is evaluated on the already-filtered
# frame so it lines up row-for-row with what it selects (the previous version
# built the mask from the unfiltered frame and relied on implicit reindexing).
df_bypost = (
    df_bypost
    .dropna(subset=['post_times'])
    .query('likes > 0')
    .loc[lambda frame: frame['username'].isin(finished_usernames)]
)

# Cyclical (sine/cosine) encoding of posting time
def fourier_encode(df):
    """Add hour / weekday / month columns and their sin-cos encodings to `df`.

    Reads the datetime column 'post_times'. The frame is modified in place and
    the same object is returned. Columns added, in order: 'hour',
    'day_of_week' (0=Monday), 'month', then '<unit>_sin' / '<unit>_cos' for
    hour (period 24), day (period 7) and month (period 12).
    """
    stamps = df['post_times'].dt
    df['hour'] = stamps.hour
    df['day_of_week'] = stamps.dayofweek  # 0=Monday
    df['month'] = stamps.month

    # (output prefix, source column, length of one full cycle)
    cycles = (
        ('hour', 'hour', 24),
        ('day', 'day_of_week', 7),
        ('month', 'month', 12),
    )
    for prefix, source, period in cycles:
        df[f'{prefix}_sin'] = np.sin(2 * np.pi * df[source] / period)
        df[f'{prefix}_cos'] = np.cos(2 * np.pi * df[source] / period)
    return df

data = fourier_encode(df_bypost)


# Time since last post (in hours)
data['timedelta'] = data['timedelta'].fillna(-1)  # First post

data['log_likes'] = np.log(data['likes']+1)

# Rolling average of log-likes over each user's previous 14 posts.
# shift(1) keeps the current post out of its own feature, so the first post of
# every user has no history and gets NaN.
data['rolling_likes'] = (
    data.groupby('username')['log_likes']
    .transform(lambda x: x.shift(1).rolling(14, min_periods=1).mean())
)


def _gradient_within_user(values):
    """np.gradient for one user's series; all-NaN when there are fewer than two posts."""
    if len(values) < 2:
        # np.gradient needs at least two samples to form a difference.
        return np.full(len(values), np.nan)
    return np.gradient(values)


# Rate of change of the rolling average, computed separately for every user.
# Differentiating the whole column at once would take differences across the
# boundary between two different accounts.
data['deriv1'] = data.groupby('username')['rolling_likes'].transform(_gradient_within_user)
data['deriv2'] = data.groupby('username')['deriv1'].transform(_gradient_within_user)

# One-hot columns for the username ('username_<name>')
dummies = pd.get_dummies(data['username'], prefix = "username")
data = pd.concat([data, dummies], axis=1)

# One-hot columns for the group ('group_<name>')
dummies = pd.get_dummies(data['group'], prefix = 'group')
data = pd.concat([data, dummies], axis=1)

In [None]:
# Number of consecutive past posts that make up one input window; used as the
# default `reachback_length` by the data-prep classes and models below.
REACHBACK = 14
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# "Before" window: posts strictly after 2023-11-01 and strictly before 2024-07-01.
# The second filter narrows `data_before`; filtering `data` again here would
# throw away the upper bound and let the "after" period into this set.
split_before = pd.to_datetime('2024-07-01')
data_before = data[data['date'] < split_before]
data_before = data_before[data_before['date'] > pd.to_datetime('2023-11-01')]

# "After" window: posts from 2024-08-01 up to and including 2024-11-01.
split = pd.to_datetime('2024-08-01')
data_after = data[data['date'] >= split]
data_after = data_after[data_after['date'] <= pd.to_datetime('2024-11-01')]

# for the data after, only include [reachback:] for each account
data_after = data_after.groupby('username').apply(lambda x: x.iloc[REACHBACK:]).reset_index(drop=True)

# models


In [None]:
def df_to_traintest(df, PrepData):
  """Turn a post-level dataframe into shuffled train / validation loaders.

  `PrepData` is instantiated with `df` and the notebook-level REACHBACK. Every
  (features, target) pair it yields is flattened to a single row, the rows are
  split 80/20 at random, and both parts are wrapped in float32 DataLoaders
  with batch size 64 (only the training loader shuffles).

  Returns:
    (train_loader, val_loader)
  """
  prepper = PrepData(df, reachback_length = REACHBACK)
  samples = list(prepper.get_dataset())
  n_samples = len(samples)

  # Stack all windows row-wise, then give each window one flat feature vector.
  features = torch.cat([sample[0] for sample in samples], dim=0)
  labels = torch.tensor([sample[1] for sample in samples])
  features = features.reshape(n_samples, prepper.reachback_length*features.shape[1])

  n_train = int(0.8*n_samples)
  train, val = torch.utils.data.random_split(
      list(zip(features, labels)), [n_train, n_samples - n_train])

  print(f"Train size: {len(train)}")
  print(f"Val size: {len(val)}")

  def as_tensor_dataset(subset):
    # Materialise a random_split subset as a float32 TensorDataset.
    xs = np.vstack([pair[0] for pair in subset])
    ys = np.array([pair[1] for pair in subset])
    return torch.utils.data.TensorDataset(torch.tensor(xs, dtype=torch.float32),
                                          torch.tensor(ys, dtype=torch.float32))

  train_dataset = as_tensor_dataset(train)
  val_dataset = as_tensor_dataset(val)

  batch_size = 64
  return (
      torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
      torch.utils.data.DataLoader(val_dataset, batch_size=batch_size),
  )

def data_loader(df,PrepData, reachback_length=REACHBACK):
    """Wrap every window of `df` in one unshuffled float32 DataLoader.

    `PrepData` is built with single_user=True. Each yielded window is flattened
    into a single feature row; the loader uses batch size 64 and keeps the
    order in which the windows were produced.
    """
    prepper = PrepData(df, single_user=True, reachback_length=reachback_length)
    samples = list(prepper.get_dataset())

    targets = np.array([sample[1] for sample in samples])
    stacked = np.vstack([sample[0] for sample in samples])
    # One row per window: reachback_length * per-post feature width.
    flat = stacked.reshape(len(samples), prepper.reachback_length*stacked.shape[1])

    tensors = torch.utils.data.TensorDataset(
        torch.tensor(flat, dtype=torch.float32),
        torch.tensor(targets, dtype=torch.float32),
    )
    return torch.utils.data.DataLoader(tensors, batch_size=64)

In [None]:
class OneBranchPrepData():
    """Builds (window, target) samples that contain caption embeddings only.

    A window is `reachback_length` consecutive rows of the dataframe; its
    target is log(likes + 1) of the row that follows the window.

    Args:
        df: post-level dataframe with 'username', 'likes' and
            'caption_embedding' columns.
        reachback_length: number of rows per window.
        single_user: when True, windows slide over `df` as one sequence;
            when False, windows are built separately for every username so a
            window never spans two accounts.
    """
    def __init__(self, df, reachback_length=REACHBACK,single_user=False):
        self.df = df
        self.reachback_length = reachback_length
        self.users = df['username'].unique()
        self.single_user = single_user

        # Precompute every (window, target) pair once.
        if self.single_user:
            sequences = [self.df]
        else:
            sequences = [self.df[self.df['username'] == user] for user in self.users]

        self.reachback = []
        for frame in sequences:
            for i in range(len(frame) - self.reachback_length):
                seq = frame.iloc[i:i+self.reachback_length]
                target = frame.iloc[i+self.reachback_length]['likes']
                target = np.log(target+1)
                self.reachback.append((seq, target))

    def __getitem__(self, idx):
        """Return (X, target) for the idx-th precomputed window."""
        reachback, target = self.reachback[idx]

        # Text embeddings, one row per post in the window
        # (presumably 384 wide, matching the model's sbert_dim default - TODO confirm).
        text_features = torch.tensor(
            np.stack(reachback['caption_embedding'].values),
            dtype=torch.float32
        )

        # Time and historical features are deliberately left out of this variant.
        X = torch.cat([text_features], dim=1)
        return X, target

    def get_dataset(self):
        """Yield every precomputed (X, target) pair whose features contain no NaN.

        `self.reachback` is one flat list covering all users, so it is walked
        by its own index. Indexing it with per-user positions (as before) made
        every user re-yield the first windows of the list instead of its own.
        """
        for idx in range(len(self.reachback)):
            value = self.__getitem__(idx)
            # check if there is any nan
            if value[0].isnan().any():
                continue
            yield value

class OneBranchLikesPredictor(nn.Module):
    """Single-branch model: caption embeddings -> linear projection -> LSTM -> MLP head.

    Args:
        sbert_dim: width of one caption embedding.
        time_dim: unused by this variant; kept so the signature matches the
            other predictors.
        historical_dim: number of extra per-post columns that follow the
            embedding in the input (0 for the data this notebook feeds it).
        hidden_dim: LSTM hidden size.
        reachback_length: number of posts per input window.
        verobose: print intermediate shapes in forward(). The misspelling is
            kept because callers pass it by keyword.
    """
    def __init__(self, sbert_dim=384, time_dim=6, historical_dim=0,
                 hidden_dim=64, reachback_length=REACHBACK,
                 verobose=False):
        super().__init__()
        self.reachback_length = reachback_length
        self.sbert_dim = sbert_dim
        self.time_dim = time_dim
        self.historical_dim = historical_dim

        self.verbose = verobose

        # Text processing
        self.text_fc = nn.Linear(sbert_dim, 128) # play with output dimension

        # LSTM over the projected text (plus historical columns, if any)
        self.lstm = nn.LSTM(
            input_size= 128 + historical_dim,
            hidden_size=hidden_dim,
            batch_first=True
        )

        # Final prediction. The head is sized from hidden_dim so a non-default
        # hidden size no longer produces a shape mismatch.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        """Map flattened windows of shape (batch, reachback_length * features) to (batch, 1)."""
        # Reshape input: (batch_size, reachback_length * total_features) ->
        # (batch_size, reachback_length, total_features)
        x = x.view(-1, self.reachback_length, self.sbert_dim + self.historical_dim)

        # Split features
        text_features = x[:, :, :self.sbert_dim]

        if self.verbose:
            print(f"Text features shape: {text_features.shape}")

        # Process text (batch_size, reachback_length, 128)
        text_emb = self.text_fc(text_features)

        if self.verbose:
            print(f"text fc: {text_emb.shape}")

        # The LSTM is built for 128 + historical_dim inputs, so append the
        # historical columns whenever there are any; with the default of 0 the
        # input is just the projected text, as before.
        if self.historical_dim:
            historical_features = x[:, :, self.sbert_dim:self.sbert_dim + self.historical_dim]
            temporal_input = torch.cat([text_emb, historical_features], dim=2)
        else:
            temporal_input = text_emb
        lstm_out, _ = self.lstm(temporal_input)
        lstm_last = lstm_out[:, -1, :]  # Last timestep

        return self.fc(lstm_last)

In [None]:
class TwoBranchPrepData():
    """Builds (window, target) samples with text, time and historical features.

    A window is `reachback_length` consecutive rows of the dataframe; its
    target is log(likes + 1) of the row that follows the window.

    Args:
        df: post-level dataframe with 'username', 'likes',
            'caption_embedding', the six sin/cos time columns, and
            'rolling_likes', 'timedelta', 'deriv1', 'deriv2'.
        reachback_length: number of rows per window.
        single_user: when True, windows slide over `df` as one sequence;
            when False, windows are built separately for every username so a
            window never spans two accounts.
    """
    def __init__(self, df, reachback_length=REACHBACK,single_user=False):
        self.df = df
        self.reachback_length = reachback_length
        self.users = df['username'].unique()
        self.single_user = single_user

        # Precompute every (window, target) pair once.
        if self.single_user:
            sequences = [self.df]
        else:
            sequences = [self.df[self.df['username'] == user] for user in self.users]

        self.reachback = []
        for frame in sequences:
            for i in range(len(frame) - self.reachback_length):
                seq = frame.iloc[i:i+self.reachback_length]
                target = frame.iloc[i+self.reachback_length]['likes']
                target = np.log(target+1)
                self.reachback.append((seq, target))

    def __getitem__(self, idx):
        """Return (X, target) for the idx-th precomputed window.

        X has one row per post: embedding columns, then 6 time columns, then
        4 historical columns.
        """
        reachback, target = self.reachback[idx]

        # Text embeddings, one row per post in the window
        # (presumably 384 wide, matching the model's sbert_dim default - TODO confirm).
        text_features = torch.tensor(
            np.stack(reachback['caption_embedding'].values),
            dtype=torch.float32
        )

        # Time features (reachback_length x 6)
        time_features = torch.tensor(
            reachback[['hour_sin', 'hour_cos', 'day_sin', 'day_cos', 'month_sin', 'month_cos']].values,
            dtype=torch.float32
        )

        # Historical features (reachback_length x 4)
        historical_features = torch.tensor(
            reachback[['rolling_likes', 'timedelta','deriv1','deriv2']].values,
            dtype=torch.float32
        )

        X = torch.cat([text_features, time_features, historical_features], dim=1)
        return X, target

    def get_dataset(self):
        """Yield every precomputed (X, target) pair whose features contain no NaN.

        `self.reachback` is one flat list covering all users, so it is walked
        by its own index. Indexing it with per-user positions (as before) made
        every user re-yield the first windows of the list instead of its own.
        """
        for idx in range(len(self.reachback)):
            value = self.__getitem__(idx)
            # check if there is any nan
            if value[0].isnan().any():
                continue
            yield value

class TwoBranchLikesPredictor(nn.Module):
    """Two-branch model: a text LSTM and a time/historical LSTM feeding one MLP head.

    Each input row is a flattened window whose per-post columns are laid out
    as [embedding | time | historical]. `verobose` (spelling kept for callers)
    switches on shape printing in forward().
    """
    def __init__(self, sbert_dim=384, time_dim=6, historical_dim=4,
                 hidden_dim=64, reachback_length=REACHBACK,
                 verobose=False):
        super().__init__()
        self.verbose = verobose
        self.reachback_length = reachback_length
        self.sbert_dim, self.time_dim, self.historical_dim = sbert_dim, time_dim, historical_dim

        # Text branch: projection followed by an LSTM with 64 hidden units
        self.text_fc = nn.Linear(sbert_dim, 128) # play with output dimension
        self.text_lstm = nn.LSTM(input_size=128, hidden_size=64, batch_first=True)

        # Time + historical branch
        self.lstm = nn.LSTM(input_size=time_dim + historical_dim,
                            hidden_size=hidden_dim,
                            batch_first=True)

        # Head over the concatenated last states of both branches
        self.fc = nn.Sequential(
            nn.Linear(64 + hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        """Map flattened windows of shape (batch, reachback_length * features) to (batch, 1)."""
        # Column boundaries of the three feature groups inside one post row.
        text_end = self.sbert_dim
        time_end = text_end + self.time_dim
        hist_end = time_end + self.historical_dim

        # (batch, reachback_length * features) -> (batch, reachback_length, features)
        x = x.view(-1, self.reachback_length, hist_end)

        text_features = x[:, :, :text_end]
        time_features = x[:, :, text_end:time_end]
        historical_features = x[:, :, time_end:hist_end]

        if self.verbose:
            print(f"Text features shape: {text_features.shape}")
            print(f"Time features shape: {time_features.shape}")
            print(f"Historical features shape: {historical_features.shape}")

        # Text branch: keep the LSTM output at the final post of the window
        text_emb = self.text_fc(text_features)
        text_seq, _ = self.text_lstm(text_emb)
        text_last = text_seq[:, -1, :]

        if self.verbose:
            print(f"text fc: {text_emb.shape}")
            print(f"Text agg shape: {text_last.shape}")

        # Time/historical branch: same last-step readout
        temporal_seq, _ = self.lstm(torch.cat([time_features, historical_features], dim=2))
        lstm_last = temporal_seq[:, -1, :]

        if self.verbose:
            print(f"LSTM output shape: {lstm_last.shape}")

        combined = torch.cat([text_last, lstm_last], dim=1)

        if self.verbose:
            print(f"Combined shape: {combined.shape}")

        return self.fc(combined)

In [None]:
class ThreeBranchPrepData():
    """Builds per-user (window, target) samples with text, time, historical and username features.

    A window is `reachback_length` consecutive posts of one user; its target
    is log(likes + 1) of that user's next post.

    Args:
        df: post-level dataframe with 'username', 'likes',
            'caption_embedding', the six sin/cos time columns,
            'rolling_likes', 'timedelta', 'deriv1', 'deriv2' and the one-hot
            'username_*' columns.
        reachback_length: number of rows per window.
        single_user: accepted so the signature matches the other prep classes;
            this class ignores it and always builds windows per user.
    """
    def __init__(self, df, reachback_length=REACHBACK,single_user=False):
        self.df = df
        self.reachback_length = reachback_length
        self.users = df['username'].unique()
        self.reachback = []

        # Precompute every (window, target) pair once, user by user.
        for user in self.users:
            user_df = self.df[self.df['username'] == user]
            for i in range(len(user_df) - (self.reachback_length)):
                seq = user_df.iloc[i:i+self.reachback_length]
                target = user_df.iloc[i+self.reachback_length]['likes']
                target = np.log(target+1)
                self.reachback.append((seq, target))

    def __getitem__(self, idx):
        """Return (X, target) for the idx-th precomputed window.

        X has one row per post: embedding columns, 6 time columns, 4
        historical columns, then every column whose name starts with 'username_'.
        """
        reachback, target = self.reachback[idx]

        # Text embeddings, one row per post in the window
        # (presumably 384 wide, matching the model's sbert_dim default - TODO confirm).
        text_features = torch.tensor(
            np.stack(reachback['caption_embedding'].values),
            dtype=torch.float32
        )

        # Time features (reachback_length x 6)
        time_features = torch.tensor(
            reachback[['hour_sin', 'hour_cos', 'day_sin', 'day_cos', 'month_sin', 'month_cos']].values,
            dtype=torch.float32
        )

        # Historical features (reachback_length x 4)
        historical_features = torch.tensor(
            reachback[['rolling_likes', 'timedelta','deriv1','deriv2']].values,
            dtype=torch.float32
        )

        # Username one-hot columns (reachback_length x number of 'username_*' columns;
        # the model below assumes 312 by default)
        username_features = torch.tensor(
            reachback[reachback.columns[reachback.columns.str.startswith('username_')]].values,
            dtype=torch.float32
        )

        X = torch.cat([text_features, time_features, historical_features, username_features], dim=1)
        return X, target

    def get_dataset(self):
        """Yield every precomputed (X, target) pair whose features contain no NaN.

        `self.reachback` is one flat list covering all users, so it is walked
        by its own index. Indexing it with per-user positions (as before) made
        every user re-yield the first windows of the list instead of its own.
        """
        for idx in range(len(self.reachback)):
            value = self.__getitem__(idx)
            # check if there is any nan
            if value[0].isnan().any():
                continue
            yield value
class ThreeBranchLikesPredictor(nn.Module):
    """Three-branch model: text LSTM, time/historical LSTM and a linear username branch.

    Each input row is a flattened window whose per-post columns are laid out
    as [embedding | time | historical | username one-hot].

    Args:
        sbert_dim: width of one caption embedding.
        time_dim: number of time columns.
        historical_dim: number of historical columns.
        hidden_dim: hidden size of the time/historical LSTM.
        reachback_length: number of posts per input window.
        verobose: print intermediate shapes in forward(). The misspelling is
            kept because callers pass it by keyword.
        username_dim: number of one-hot username columns. Defaults to 312, the
            value that used to be hard-coded; pass the actual count when the
            set of accounts differs.
    """
    def __init__(self, sbert_dim=384, time_dim=6, historical_dim=4,
                 hidden_dim=64, reachback_length=REACHBACK,
                 verobose=False, username_dim=312):
        super().__init__()
        self.reachback_length = reachback_length
        self.sbert_dim = sbert_dim
        self.time_dim = time_dim
        self.historical_dim = historical_dim
        self.username_dim = username_dim

        self.verbose = verobose

        # Text processing
        self.text_fc = nn.Linear(sbert_dim, 128) # play with output dimension
        self.text_lstm = nn.LSTM(
            input_size=128,
            hidden_size=64,
            batch_first=True
        )

        # LSTM for temporal/historical patterns
        self.lstm = nn.LSTM(
            input_size=time_dim + historical_dim,
            hidden_size=hidden_dim,
            batch_first=True
        )

        # Linear map from the username one-hot to a single scalar per post
        self.username_fc = nn.Linear(self.username_dim, 1)

        # Final prediction: text state (64) + temporal state (hidden_dim) + username scalar (1)
        self.fc = nn.Sequential(
            nn.Linear(64 + hidden_dim + 1, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        """Map flattened windows of shape (batch, reachback_length * features) to (batch, 1)."""
        # Column boundaries of the four feature groups inside one post row.
        text_end = self.sbert_dim
        time_end = text_end + self.time_dim
        hist_end = time_end + self.historical_dim
        user_end = hist_end + self.username_dim

        # Reshape input: (batch_size, reachback_length * total_features) ->
        # (batch_size, reachback_length, total_features)
        x = x.view(-1, self.reachback_length, user_end)

        # Split features
        text_features = x[:, :, :text_end]
        time_features = x[:, :, text_end:time_end]
        historical_features = x[:, :, time_end:hist_end]
        username_features = x[:, :, hist_end:user_end]

        if self.verbose:
            print(f"Text features shape: {text_features.shape}")
            print(f"Time features shape: {time_features.shape}")
            print(f"Historical features shape: {historical_features.shape}")
            print(f"Username features shape: {username_features.shape}")

        # Process text (batch_size, reachback_length, 128), keep the last timestep
        text_emb = self.text_fc(text_features)
        text_lstm = self.text_lstm(text_emb)
        text_last = text_lstm[0][:, -1, :]  # Last timestep

        if self.verbose:
            print(f"text fc: {text_emb.shape}")
            print(f"Text agg shape: {text_last.shape}")

        # Process time + historical features (batch_size, reachback_length, time_dim + historical_dim)
        temporal_input = torch.cat([time_features, historical_features], dim=2)
        lstm_out, _ = self.lstm(temporal_input)
        lstm_last = lstm_out[:, -1, :]  # Last timestep

        if self.verbose:
            print(f"LSTM output shape: {lstm_last.shape}")

        # Process username: (batch_size, reachback_length, 1), averaged over the window
        username_emb = self.username_fc(username_features)
        username_agg = torch.mean(username_emb, dim=1)  # Average over sequence

        # Combine features
        combined = torch.cat([text_last, lstm_last, username_agg], dim=1)

        if self.verbose:
            print(f"Combined shape: {combined.shape}")

        return self.fc(combined)

In [None]:
def train_model(model, train_loader, val_loader, lr=0.001, num_epochs= 10, verbose = False):
  """Train `model` with Adam and MSE loss, validating after every epoch.

  Args:
    model: module mapping a batch of inputs to one prediction per sample.
    train_loader, val_loader: loaders yielding (inputs, targets) batches.
    lr: Adam learning rate.
    num_epochs: number of passes over `train_loader`.
    verbose: print both losses after each epoch.

  Returns:
    (model, train_loss_history, val_loss_history); each history holds one
    sample-weighted mean loss per epoch.
  """
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  criterion = nn.MSELoss()
  # Send batches to wherever the model lives instead of relying on a
  # notebook-level `device` that may not match it.
  model_device = next(model.parameters()).device

  train_loss_history = []
  val_loss_history = []
  for epoch in tqdm(range(num_epochs)):
    # Training phase
    model.train()
    train_loss = 0.0
    for inputs, targets in train_loader:
      inputs = inputs.to(model_device)
      targets = targets.to(model_device).float()
      optimizer.zero_grad()

      # Match the target shape explicitly: a bare squeeze() turns a batch of
      # one into a 0-d tensor and makes MSELoss broadcast with a warning.
      outputs = model(inputs).reshape(targets.shape)
      loss = criterion(outputs, targets)
      loss.backward()
      optimizer.step()

      # Weight by batch size so a short last batch does not skew the mean.
      train_loss += loss.item() * inputs.size(0)

    # Calculate average training loss
    train_loss = train_loss / len(train_loader.dataset)

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
      for inputs, targets in val_loader:
        inputs = inputs.to(model_device)
        targets = targets.to(model_device).float()
        outputs = model(inputs).reshape(targets.shape)
        loss = criterion(outputs, targets)
        val_loss += loss.item() * inputs.size(0)

    val_loss = val_loss / len(val_loader.dataset)

    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)

    if verbose:
      print(f'Epoch {epoch+1}/{num_epochs}')
      print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
      print('-' * 50)
  return model, train_loss_history, val_loss_history

def test_model(model, test_loader):
  criterion = nn.MSELoss()
  model.eval()
  test_loss = 0.0
  test_loss_unlog = 0.0
  with torch.no_grad():
      for inputs, targets in test_loader:
          outputs = model(inputs)
          loss = criterion(outputs.squeeze(), targets)
          test_loss += loss.item() * inputs.size(0)

          outputs_unlog = torch.exp(outputs)
          targets_unlog = torch.exp(targets)
          loss_unlog = criterion(outputs_unlog.squeeze(), targets_unlog)
          test_loss_unlog += loss_unlog.item() * inputs.size(0)

  test_loss = test_loss / len(test_loader.dataset)
  test_loss_unlog = test_loss_unlog / len(test_loader.dataset)

  return test_loss, test_loss_unlog

# data

In [None]:
# Simple (one-branch) loaders: one (train_loader, val_loader) pair per group

trainval_grouped1 = {}
for group, usernames in groups.items():
  group_df = data_before.query('username in @usernames')
  n_posts = len(group_df)
  print(group, n_posts, '-'*30)
  # Groups with 100 posts or fewer are skipped.
  if n_posts <= 100:
    print(f'{group} not enough data')
    continue
  trainval_grouped1[group] = df_to_traintest(group_df, PrepData = OneBranchPrepData)

In [None]:
# two branch loaders: one (train_loader, val_loader) pair per group

trainval_grouped2 = {}
for group, usernames in groups.items():
  group_df = data_before.query('username in @usernames')
  n_posts = len(group_df)
  print(group, n_posts, '-'*30)
  # Groups with 100 posts or fewer are skipped.
  if n_posts <= 100:
    print(f'{group} not enough data')
    continue
  trainval_grouped2[group] = df_to_traintest(group_df, TwoBranchPrepData)

In [None]:
# three branch loaders: one (train_loader, val_loader) pair per group

trainval_grouped3 = {}
for group, usernames in groups.items():
  group_df = data_before.query('username in @usernames')
  n_posts = len(group_df)
  print(group, n_posts, '-'*30)
  # Groups with 100 posts or fewer are skipped.
  if n_posts <= 100:
    print(f'{group} not enough data')
    continue
  trainval_grouped3[group] = df_to_traintest(group_df, ThreeBranchPrepData)

# train

In [None]:
# Train one one-branch model per group and keep it together with its loss curves.
model_grouped1 = {}
# The device does not change between groups, so pick it once.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for group, (train_loader, val_loader) in trainval_grouped1.items():
  print(f"Training {group}")
  model = OneBranchLikesPredictor(sbert_dim=384, reachback_length=REACHBACK, verobose=False).to(device)
  print(f"Using device: {device}")

  model, train_loss_history, val_loss_history = train_model(
      model, train_loader, val_loader,
      lr=0.001, num_epochs=25, verbose=False,
  )
  model_grouped1[group] = (model, train_loss_history, val_loss_history)

In [None]:
# One small panel per group (3x5 grid) with the train and validation loss curves;
# the final validation loss is written in the middle of each panel.
for i, (group, (model, train_loss_history, val_loss_history)) in enumerate(model_grouped1.items()):
  plt.subplot(3,5,i+1)
  plt.text(0.5, 0.5, f'{group}:\n{val_loss_history[-1]:.2f}', ha='center', va='center', transform=plt.gca().transAxes)

  train_epochs = range(len(train_loss_history))
  plt.plot(train_loss_history, label='Train Loss', )
  plt.scatter(train_epochs, train_loss_history, color='tab:blue', s=5)

  val_epochs = range(len(val_loss_history))
  plt.plot(val_loss_history, label='Validation Loss')
  plt.scatter(val_epochs, val_loss_history, color='tab:orange', s= 5)

  # Only the first panel of each row gets a y-axis label; the others lose their ticks.
  if i % 5 == 0:
    plt.ylabel('Loss')
  else:
    plt.yticks([])
  plt.xlabel('Epoch')

plt.legend(loc = 'lower right', bbox_to_anchor=(2.1,0.3))
plt.show()

In [None]:
# model_grouped2deriv = {}
# for group, (train_loader, val_loader) in trainval_grouped2deriv.items():
#   print(f"Training {group}")
#   model = TwoBranchDerivLikesPredictor(
#       sbert_dim=384,
#       reachback_length=REACHBACK,
#       verobose=False
#   )

#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   # device = torch.device('cpu')
#   model = model.to(device)
#   print(f"Using device: {device}")

#   model, train_loss_history, val_loss_history = train_model(model, train_loader, val_loader, lr=0.001, verbose=False, num_epochs=20)

#   model_grouped2deriv[group] = (model, train_loss_history, val_loss_history)

In [None]:
# for i,group in enumerate(model_grouped2deriv.keys()):
#   model, train_loss_history, val_loss_history = model_grouped2deriv[group]
#   plt.subplot(3,5,i+1)
#   plt.text(0.5, 0.5, f'{group}:\n{val_loss_history[-1]:.2f}', ha='center', va='center', transform=plt.gca().transAxes)
#   plt.plot(train_loss_history, label='Train Loss', )
#   plt.scatter(range(len(train_loss_history)), train_loss_history, color='tab:blue', s=5)
#   plt.plot(val_loss_history, label='Validation Loss')
#   plt.scatter(range(len(val_loss_history)), val_loss_history, color='tab:orange', s= 5)
#   if i % 5 != 0:
#     plt.yticks([])
#   else:
#     plt.ylabel('Loss')
#   plt.xlabel('Epoch')

# plt.legend(loc = 'lower right', bbox_to_anchor=(2.1,0.3))
# plt.show()

In [None]:
# Train one two-branch model per group and keep it together with its loss curves.
model_grouped2 = {}
# The device does not change between groups, so pick it once.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for group, (train_loader, val_loader) in trainval_grouped2.items():
  print(f"Training {group}")
  model = TwoBranchLikesPredictor(sbert_dim=384, reachback_length=REACHBACK, verobose=False).to(device)
  print(f"Using device: {device}")

  model, train_loss_history, val_loss_history = train_model(
      model, train_loader, val_loader,
      lr=0.001, num_epochs=25, verbose=False,
  )
  model_grouped2[group] = (model, train_loss_history, val_loss_history)

In [None]:
# One small panel per group (3x5 grid) with the train and validation loss curves;
# the final validation loss is written in the middle of each panel.
for i, (group, (model, train_loss_history, val_loss_history)) in enumerate(model_grouped2.items()):
  plt.subplot(3,5,i+1)
  plt.text(0.5, 0.5, f'{group}:\n{val_loss_history[-1]:.2f}', ha='center', va='center', transform=plt.gca().transAxes)

  train_epochs = range(len(train_loss_history))
  plt.plot(train_loss_history, label='Train Loss', )
  plt.scatter(train_epochs, train_loss_history, color='tab:blue', s=5)

  val_epochs = range(len(val_loss_history))
  plt.plot(val_loss_history, label='Validation Loss')
  plt.scatter(val_epochs, val_loss_history, color='tab:orange', s= 5)

  # Only the first panel of each row gets a y-axis label; the others lose their ticks.
  if i % 5 == 0:
    plt.ylabel('Loss')
  else:
    plt.yticks([])
  plt.xlabel('Epoch')

plt.legend(loc = 'lower right', bbox_to_anchor=(2.1,0.3))
plt.show()

In [None]:
# Train one three-branch model per group and keep it together with its loss curves.
model_grouped3 = {}
# The device does not change between groups, so pick it once.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for group, (train_loader, val_loader) in trainval_grouped3.items():
  print(f"Training {group}")
  model = ThreeBranchLikesPredictor(sbert_dim=384, reachback_length=REACHBACK, verobose=False).to(device)
  print(f"Using device: {device}")

  model, train_loss_history, val_loss_history = train_model(
      model, train_loader, val_loader,
      lr=0.001, num_epochs=25, verbose=False,
  )
  model_grouped3[group] = (model, train_loss_history, val_loss_history)

In [None]:
# One small panel per group (3x5 grid) with the train and validation loss curves;
# the final validation loss is written in the middle of each panel.
for i, (group, (model, train_loss_history, val_loss_history)) in enumerate(model_grouped3.items()):
  plt.subplot(3,5,i+1)
  plt.text(0.5, 0.5, f'{group}:\n{val_loss_history[-1]:.2f}', ha='center', va='center', transform=plt.gca().transAxes)

  train_epochs = range(len(train_loss_history))
  plt.plot(train_loss_history, label='Train Loss', )
  plt.scatter(train_epochs, train_loss_history, color='tab:blue', s=5)

  val_epochs = range(len(val_loss_history))
  plt.plot(val_loss_history, label='Validation Loss')
  plt.scatter(val_epochs, val_loss_history, color='tab:orange', s= 5)

  # Only the first panel of each row gets a y-axis label; the others lose their ticks.
  if i % 5 == 0:
    plt.ylabel('Loss')
  else:
    plt.yticks([])
  plt.xlabel('Epoch')

plt.legend(loc = 'lower right', bbox_to_anchor=(2.1,0.3))
plt.show()

In [None]:
# Final-epoch validation loss of every group, keyed by architecture name.
loss = {}
# loss['one'] = [1.3042222829729941, 0.1322471123636643, 0.06989462532554612, 0.4563507162349325, 0.4942909843921661, 0.3941170771916707, 1.0540281456870002, 0.1729786756331027, 0.49412976740794595, 0.9720966219902039, 0.475496735415894, 0.5845634161707867, 0.785887829308371, 5.336733749934605]
for branches,architecture in zip(['three'],[model_grouped3]):
  # Store the list once per architecture. Assigning inside the per-group loop
  # repeated the assignment on every iteration and left out architectures
  # that have no trained groups.
  loss[branches] = [
      group_val_history[-1]
      for (_, _, group_val_history) in architecture.values()
  ]
print(loss)

In [None]:
# Final-epoch validation loss of every group, keyed by architecture name.
loss = {}
# loss['one'] = [1.3042222829729941, 0.1322471123636643, 0.06989462532554612, 0.4563507162349325, 0.4942909843921661, 0.3941170771916707, 1.0540281456870002, 0.1729786756331027, 0.49412976740794595, 0.9720966219902039, 0.475496735415894, 0.5845634161707867, 0.785887829308371, 5.336733749934605]
for branches,architecture in zip(['one','two'],[model_grouped1, model_grouped2]):
  # Store the list once per architecture. Assigning inside the per-group loop
  # repeated the assignment on every iteration and left out architectures
  # that have no trained groups.
  loss[branches] = [
      group_val_history[-1]
      for (_, _, group_val_history) in architecture.values()
  ]
print(loss)

In [None]:
# Hard-coded final validation losses, 14 values per architecture. This rebinds
# `loss`, replacing whatever the preceding cells computed.
# NOTE(review): presumably copied from an earlier training run so the figure
# below can be drawn without retraining - confirm they are still current.
loss = {'one': [0.8119317060648267, 0.09030180714887104, 0.05201016272190201, 0.17431352507205447, 0.6155957554249053, 0.23071729640165964, 0.6487047892531488, 0.14092928212549952, 0.38789024254513127, 1.0574390888214111, 0.8770792104417177, 0.6949891581010789, 0.11403568857506856, 0.44171265612787275], 'two': [0.9232050452789251, 0.0576316925134286, 0.03779425060759499, 0.15334679112587432, 0.16287271551629331, 0.08917510806111256, 0.6494924682008345, 0.1172322551569631, 0.3037592445751363, 0.8925648927688599, 0.5258060012818917, 0.6455977026436447, 0.04819292140681906, 0.2598202693642992]}
loss['three'] =  [0.6245497078356081, 0.09411754319691115, 0.03141997889968589, 0.13235050077965715, 0.21072650399613888, 0.11033004692178808, 0.5246508916219076, 0.08843205912274403, 0.3231729140820822, 1.8471211194992065, 0.505800309002831, 0.6185696541775391, 0.05751745161684723, 0.30766715481877327]


In [None]:
# loss = {}
# # loss['one'] = [1.3042222829729941, 0.1322471123636643, 0.06989462532554612, 0.4563507162349325, 0.4942909843921661, 0.3941170771916707, 1.0540281456870002, 0.1729786756331027, 0.49412976740794595, 0.9720966219902039, 0.475496735415894, 0.5845634161707867, 0.785887829308371, 5.336733749934605]
# for branches,architecture in zip(['one','two','three'],[model_grouped1, model_grouped2, model_grouped3]):
#   final_loss = []
#   for group, (model, train_loss_history, val_loss_history) in architecture.items():
#     final_loss += [val_loss_history[-1]]
#     loss[branches] = final_loss


# Strip plot of the per-group final validation losses, one row per architecture,
# with each architecture's median marked in red.
plt.figure(figsize = (10,2))

# Row labels and point counts come from `loss` itself, so the figure stays
# correct when the number of groups or architectures changes.
branch_labels = [name.capitalize() for name in loss]
branch_losses = list(loss.values())
medians = [np.median(values) for values in branch_losses]

plt.scatter(
    [value for values in branch_losses for value in values],
    [label for label, values in zip(branch_labels, branch_losses) for _ in values],
    marker = 'x', c = 'black', label = 'Final Loss (group)')
# plot median
plt.scatter(medians, branch_labels, c = 'red', label = 'Median Loss')
for median in medians:
  plt.axvline(median, color='red', linestyle='--', lw = 1)
# plt.xlim([0,5])
plt.xlabel('Validation Loss')
# Categorical rows sit at 0..n-1; leave one unit of padding on either side.
plt.ylim([-1,len(branch_labels)])
plt.ylabel('Branches')
plt.gca().invert_yaxis()

plt.legend()

plt.savefig(predir+'/figs/model_loss.png', dpi = 800)

# residuals

## one branch

In [None]:
# Residuals (target - prediction) of every one-branch group model, evaluated
# on that group's posts from the "before" and "after" periods.
# Maps group -> (residuals_before, residuals_after); each is a list of 0-d
# tensors.
residuals_grouped1 = {}

for group, (model, train_loss_history, val_loss_history) in tqdm(model_grouped1.items()):
  usernames = groups[group]
  data_after_group = data_after.query('username in @usernames')
  dataloader_after = data_loader(data_after_group, OneBranchPrepData)
  dataloader_before = data_loader(data_before.query('username in @usernames'), OneBranchPrepData)
  model.eval()
  residuals_before = []
  residuals_after = []
  with torch.no_grad():
    # Same evaluation order as before: "after" first, then "before".
    for _loader, _sink in ((dataloader_after, residuals_after), (dataloader_before, residuals_before)):
      for inputs, targets in _loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Move to the CPU so the residuals can be plotted, tested and pickled
        # regardless of the device the model ran on.
        _sink.extend((targets - outputs.squeeze()).cpu())

  residuals_grouped1[group] = (residuals_before, residuals_after)

In [None]:
# Pool the one-branch residuals of every group and compare the "before" and
# "after" distributions in a single histogram.
all_before1, all_after1 = [], []
for group, (residuals_before, residuals_after) in residuals_grouped1.items():
  all_before1 += residuals_before
  all_after1 += residuals_after

hist_kwargs = dict(alpha = 0.5, density = True, bins = np.arange(-10,10,1))
plt.hist(all_before1, label = 'Before June 2023', **hist_kwargs)
plt.hist(all_after1, label = 'After February 2024', **hist_kwargs)
plt.legend()
plt.xlabel('Residuals (log likes)')
# plt.savefig(predir+'/figs/residuals_two_branch.png', dpi = 800)

## two branch

In [None]:
# Residuals (target - prediction) of every two-branch group model, evaluated
# on that group's posts from the "before" and "after" periods.
# Maps group -> (residuals_before, residuals_after); each is a list of 0-d
# tensors.
residuals_grouped2 = {}

for group, (model, train_loss_history, val_loss_history) in tqdm(model_grouped2.items()):
  usernames = groups[group]
  data_after_group = data_after.query('username in @usernames')
  dataloader_after = data_loader(data_after_group, TwoBranchPrepData)
  dataloader_before = data_loader(data_before.query('username in @usernames'), TwoBranchPrepData)
  model.eval()
  residuals_before = []
  residuals_after = []
  with torch.no_grad():
    # Same evaluation order as before: "after" first, then "before".
    for _loader, _sink in ((dataloader_after, residuals_after), (dataloader_before, residuals_before)):
      for inputs, targets in _loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Move to the CPU so the residuals can be plotted, tested and pickled
        # regardless of the device the model ran on.
        _sink.extend((targets - outputs.squeeze()).cpu())

  residuals_grouped2[group] = (residuals_before, residuals_after)

In [None]:
# One panel per group: two-branch residual distributions before vs. after.
fig, axes = plt.subplots(3,5, figsize=(10,5), sharey = True, sharex = True)
plt.subplots_adjust(wspace=0, hspace=0)

residual_bins = np.arange(-10.5,10.5,1)

for i, (ax,(group, (residuals_before, residuals_after))) in enumerate(zip(axes.flat,residuals_grouped2.items())):
  ax.set_xlim([-12, 12])
  ax.hist(residuals_before, alpha = 0.5 ,density=True, label = 'Before Nov 2023 (train+val)', bins = residual_bins)
  ax.hist(residuals_after, alpha = 0.5 ,density=True, label = 'After Feb 2024', bins = residual_bins)
  ax.text(0.5, 0.75, group_titles[group], ha='center', va='center', transform=ax.transAxes)

  # dashed reference line at zero residual
  ax.vlines(0, 0, 0.41, linestyles='dashed', colors='black', alpha = 0.25)

  # Attach the legend to the last populated panel. The index must come from
  # the dict being plotted here, not from the one-branch results.
  if i == len(residuals_grouped2) - 1:
    ax.legend(loc = 'lower right', bbox_to_anchor=(2.1,0.3))


# The grid has 15 slots; the last one is removed to leave room for the legend.
fig.delaxes(axes[2][4])
fig.supxlabel('Residuals (log likes)')
fig.supylabel('Frequency', x=0.055)
# fig.tight_layout()

# plt.savefig(predir+'/figs/residuals_large_feb2024.png')

In [None]:
# Pool the two-branch residuals of every group and compare the "before" and
# "after" distributions in a single histogram.
all_before2, all_after2 = [], []
for group, (residuals_before, residuals_after) in residuals_grouped2.items():
  all_before2 += residuals_before
  all_after2 += residuals_after

hist_kwargs = dict(alpha = 0.5, density = True, bins = np.arange(-10,10,1))
plt.hist(all_before2, label = 'Before June 2023', **hist_kwargs)
plt.hist(all_after2, label = 'After February 2024', **hist_kwargs)
plt.legend()
plt.xlabel('Residuals (log likes)')
# plt.savefig(predir+'/figs/residuals_two_branch.png', dpi = 800)

## three branch

In [None]:
# Residuals (target - prediction) of every three-branch group model, evaluated
# on that group's posts from the "before" and "after" periods.
# Maps group -> (residuals_before, residuals_after); each is a list of 0-d
# tensors. Groups with fewer than 30 "after" rows are skipped.
residuals_grouped3 = {}

for group, (model, train_loss_history, val_loss_history) in tqdm(model_grouped3.items()):
  usernames = groups[group]
  data_after_group = data_after.query('username in @usernames')
  if len(data_after_group) < 30:
    continue
  dataloader_after = data_loader(data_after_group, ThreeBranchPrepData)
  dataloader_before = data_loader(data_before.query('username in @usernames'), ThreeBranchPrepData)
  model.eval()
  residuals_before = []
  residuals_after = []
  with torch.no_grad():
    # Same evaluation order as before: "after" first, then "before".
    for _loader, _sink in ((dataloader_after, residuals_after), (dataloader_before, residuals_before)):
      for inputs, targets in _loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Move to the CPU so the residuals can be plotted, tested and pickled
        # regardless of the device the model ran on.
        _sink.extend((targets - outputs.squeeze()).cpu())

  residuals_grouped3[group] = (residuals_before, residuals_after)

In [None]:
# Pool the three-branch residuals of every group and compare the "before" and
# "after" distributions in a single histogram.
all_before3, all_after3 = [], []
for group, (residuals_before, residuals_after) in residuals_grouped3.items():
  all_before3 += residuals_before
  all_after3 += residuals_after

hist_kwargs = dict(alpha = 0.5, density = True, bins = np.arange(-10,10,1))
plt.hist(all_before3, label = 'Before June 2023', **hist_kwargs)
plt.hist(all_after3, label = 'After February 2024', **hist_kwargs)
plt.legend()
plt.xlabel('Residuals (log likes)')
# plt.savefig(predir+'/figs/residuals_two_branch.png', dpi = 800)

## agg

In [None]:
# Load the three saved per-architecture figures from the figs folder so they
# can be shown side by side in the next cell.
png1, png2, png3 = (
    plt.imread(f"{predir}/figs/onebranch.png"),
    plt.imread(f"{predir}/figs/twobranch.png"),
    plt.imread(f"{predir}/figs/threebranch.png"),
)

In [None]:
# Show the three saved residual figures side by side as image panels.
plt.figure(figsize = (12,4))
plt.subplots_adjust(wspace=0, hspace=0)

# Proxy legend entries: the panels are images, so they carry no artists of
# their own to build a legend from.
blue = mpatches.Patch(color='tab:blue', label = 'Before', alpha = 0.5)
orange = mpatches.Patch(color='tab:orange', label = 'After', alpha = 0.5)

panel_images = [(png1, 'One Branch'), (png2, 'Two Branch'), (png3, 'Three Branch')]
for panel_no, (panel_image, panel_title) in enumerate(panel_images, start=1):
  plt.subplot(1,3,panel_no)
  plt.imshow(panel_image)
  plt.title(panel_title)
  # hide ticks and frame; the image already contains its own axes
  plt.yticks([])
  plt.xticks([])
  plt.gca().set(frame_on=False)
  plt.xlabel('Residuals (log likes)')
  if panel_no == 1:
    # only the left-most panel carries the y label and the legend
    plt.ylabel('Density')
    plt.legend(handles=[blue, orange], loc = 'upper left', bbox_to_anchor = (0.15,0.98))

plt.savefig(predir+'/figs/residuals_compared_volatile.png', dpi = 800,bbox_inches='tight')

In [None]:
# Pooled residual histograms of the three architectures, side by side.
plt.figure(figsize = (12,4))

# (title, before, after, before label, after label, bins) for each panel.
# NOTE(review): the third panel bins over [-5, 5) while the other two use
# [-10, 10) — confirm the narrower range is intentional.
hist_panels = [
    ('One Branch', all_before1, all_after1, 'Before', 'After', np.arange(-10,10,1)),
    ('Two Branch', all_before2, all_after2, 'Before June 2023', 'After February 2024', np.arange(-10,10,1)),
    ('Three Branch', all_before3, all_after3, 'Before June 2023', 'After February 2024', np.arange(-5,5,1)),
]
for panel_no, (panel_title, panel_before, panel_after, before_label, after_label, panel_bins) in enumerate(hist_panels, start=1):
  plt.subplot(1,3,panel_no)
  plt.hist(panel_before, alpha = 0.5, label = before_label, density = True, bins = panel_bins)
  plt.hist(panel_after, alpha = 0.5, label = after_label, density = True, bins = panel_bins)
  if panel_no == 1:
    # the legend is drawn on the first panel only
    plt.legend(loc='upper left')
  plt.xlabel('Residuals (log likes)')
  if panel_no == 1:
    plt.ylabel('Density')
  plt.title(panel_title)

plt.savefig(predir+'/figs/residuals_compared.png', dpi = 800,bbox_inches='tight')

# split feb 2024

In [None]:
# Training period: everything strictly before 1 Nov 2023.
split_before = pd.to_datetime('2023-11-01')
data_before = data[data['date'] < split_before]

# Evaluation period: everything from 1 Feb 2024 onwards.
split = pd.to_datetime('2024-02-01')
data_after = data[data['date'] >= split]

# For the data after, drop the first REACHBACK posts of each account.
# cumcount() numbers each account's rows 0, 1, 2, ... in their current order.
# The boolean mask replaces groupby.apply, which relied on the grouping column
# being passed to the function (deprecated in pandas 2.2, so `username` would
# be lost). The stable sort reproduces apply's row order: accounts sorted,
# original order kept within each account.
posts_seen = data_after.groupby('username').cumcount()
data_after = (
    data_after[posts_seen >= REACHBACK]
    .sort_values('username', kind='stable')
    .reset_index(drop=True)
)

In [None]:
# two branch
# Build (train, validation) loaders per group from the pre-split data;
# groups with 100 rows or fewer are reported and left out.

trainval_grouped_feb2024 = {}
for group, usernames in groups.items():
  group_df = data_before.query('username in @usernames')
  n_group_rows = len(group_df)
  print(group, n_group_rows, '-'*30)
  if n_group_rows <= 100:
    print(f'{group} not enough data')
    continue
  trainval_grouped_feb2024[group] = df_to_traintest(group_df, TwoBranchPrepData)

In [None]:
# Train one two-branch model per group on the Feb-2024 split's loaders.
# Maps group -> (trained model, train loss history, validation loss history).
model_grouped_feb2024 = {}
for group, (train_loader, val_loader) in trainval_grouped_feb2024.items():
  print(f"Training {group}")
  # NOTE(review): the keyword is spelled `verobose`; presumably it matches the
  # constructor's parameter name — confirm against TwoBranchLikesPredictor.
  model = TwoBranchLikesPredictor(
      sbert_dim=384,
      reachback_length=REACHBACK,
      verobose=False
  )

  # Use the GPU when available; `device` is also read by the residual cell below.
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  # device = torch.device('cpu')
  model = model.to(device)
  print(f"Using device: {device}")

  model, train_loss_history, val_loss_history = train_model(model, train_loader, val_loader, lr=0.001, verbose=False, num_epochs=20)

  model_grouped_feb2024[group] = (model, train_loss_history, val_loss_history)

In [None]:
# Residuals (target - prediction) of every Feb-2024 group model.
# Maps group -> (residuals_before_train, residuals_before_val,
# residuals_before, residuals_after); each is a list of 0-d tensors.
residuals_grouped_feb2024 = {}

for group, (model, train_loss_history, val_loss_history) in tqdm(model_grouped_feb2024.items()):
  usernames = groups[group]
  data_after_group = data_after.query('username in @usernames')
  data_before_group = data_before.query('username in @usernames')
  dataloader_after = data_loader(data_after_group, TwoBranchPrepData)
  dataloader_before = data_loader(data_before_group, TwoBranchPrepData)
  # The loaders the model was actually trained / validated on.
  dataloader_before_train = trainval_grouped_feb2024[group][0]
  dataloader_before_val = trainval_grouped_feb2024[group][1]

  model.eval()
  residuals_before = []
  residuals_before_train = []
  residuals_before_val = []
  residuals_after = []
  with torch.no_grad():
    # Same evaluation order as before: after, train, val, then all "before".
    for _loader, _sink in (
        (dataloader_after, residuals_after),
        (dataloader_before_train, residuals_before_train),
        (dataloader_before_val, residuals_before_val),
        (dataloader_before, residuals_before),
    ):
      for inputs, targets in _loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Move to the CPU so the residuals can be plotted, tested and pickled
        # regardless of the device the model ran on.
        _sink.extend((targets - outputs.squeeze()).cpu())

  residuals_grouped_feb2024[group] = (residuals_before_train,residuals_before_val, residuals_before, residuals_after)

## plot residuals

In [None]:
# One panel per group: Feb-2024 split residual distributions before vs. after.
fig, axes = plt.subplots(3,5, figsize=(10,5), sharey = True, sharex = True)
plt.subplots_adjust(wspace=0, hspace=0)

residual_bins = np.arange(-10.5,10.5,1)

for i, (ax,(group, (residuals_before_train,residuals_before_val, residuals_before, residuals_after))) in enumerate(zip(axes.flat,residuals_grouped_feb2024.items())):
  ax.set_xlim([-12, 12])
  ax.hist(residuals_before, alpha = 0.5 ,density=True, label = 'Before Nov 2023', bins = residual_bins)
  ax.hist(residuals_after, alpha = 0.5 ,density=True, label = 'After Feb 2024', bins = residual_bins)
  ax.text(0.5, 0.75, group_titles[group], ha='center', va='center', transform=ax.transAxes)

  # dashed reference line at zero residual
  ax.vlines(0, 0, 0.41, linestyles='dashed', colors='black', alpha = 0.25)

  # Attach the legend to the last populated panel. The index must come from
  # the dict being plotted here, not from the one-branch results.
  if i == len(residuals_grouped_feb2024) - 1:
    ax.legend(loc = 'lower right', bbox_to_anchor=(2.1,0.3))


# The grid has 15 slots; the last one is removed to leave room for the legend.
fig.delaxes(axes[2][4])
fig.supxlabel('Residuals (log likes)')
fig.supylabel('Frequency', x=0.055)
# fig.tight_layout()

# plt.savefig(predir+'/figs/residuals_large_feb2024.png')

In [None]:
# t-test on averages before and after
# One two-sample t-test per group comparing the model residuals before and
# after the split, with the confidence interval of the mean difference.
from scipy.stats import ttest_ind
from scipy.stats import permutation_test
from datetime import timedelta

ttest_rows = []

for group, (residuals_before_train,residuals_before_val, residuals_before, residuals_after) in tqdm(residuals_grouped_feb2024.items()):

  # float() accepts 0-d tensors on any device as well as plain numbers, so
  # scipy always receives ordinary float arrays.
  before_values = np.array([float(r) for r in residuals_before])
  after_values = np.array([float(r) for r in residuals_after])

  # ttest = ttest_ind(before, after, permutations=10_000)
  ttest = ttest_ind(before_values, after_values)

  interval = ttest.confidence_interval()
  ttest_rows.append({'group': group, 'tstat': ttest.statistic, 'pval': ttest.pvalue, 'low': interval.low, 'high': interval.high})

# Build the frame once instead of concatenating onto an empty frame per group.
df_ttest = pd.DataFrame(ttest_rows, columns=['group','tstat','pval','low','high'])

# bonferroni correction
# NOTE(review): 14 is presumably the total number of groups — confirm, since
# groups without a trained model are not tested here.
df_ttest['signif'] = (df_ttest['pval'] * 14) < 0.05

In [None]:
# Horizontal bar chart of the per-group t statistics held in `df_ttest`.
fig, ax1 = plt.subplots(1,1, figsize=(10,5), sharey = True)

plt.subplots_adjust(wspace=0, hspace=0)

# Rows are reversed so the first group ends up at the top of the chart.
ttest_plotted = df_ttest[::-1]
ttest_signif = ttest_plotted[ttest_plotted['signif'].astype(bool)]
n_bars = len(ttest_plotted)

ax1.barh(ttest_plotted['group'], ttest_plotted['tstat'], alpha = 0.5)
# go again for the significant ones
ax1.barh(ttest_signif['group'], ttest_signif['tstat'], color = 'tab:blue', zorder = 2)

# use fill between to color positive values green and negative values red
ax1.fill_betweenx([0,n_bars-1], -30, 0, color='red', alpha=0.25, zorder = 1)
ax1.fill_betweenx([0,n_bars-1], 0, 30, color='green', alpha=0.25, zorder = 1)

ax1.grid(zorder=1)
# Label each bar from the rows actually plotted, so the titles stay aligned
# with the bars even when some groups were skipped and are absent from df_ttest.
ax1.set_yticks(range(n_bars), [group_titles.get(g, g) for g in ttest_plotted['group']])
ax1.set_ylabel('Group')
ax1.set_xlabel('T-Statistic')

# plt.savefig(predir+'/figs/ttest_large_feb2024.png')

# split jan 2025

In [None]:
# Training period: everything strictly before 1 Oct 2024.
split_before = pd.to_datetime('2024-10-01')
data_before = data[data['date'] < split_before]

# Evaluation period: everything from 6 Jan 2025 onwards.
split = pd.to_datetime('2025-01-06')
data_after = data[data['date'] >= split]

# For the data after, drop the first REACHBACK posts of each account.
# cumcount() numbers each account's rows 0, 1, 2, ... in their current order.
# The boolean mask replaces groupby.apply, which relied on the grouping column
# being passed to the function (deprecated in pandas 2.2, so `username` would
# be lost). The stable sort reproduces apply's row order: accounts sorted,
# original order kept within each account.
posts_seen = data_after.groupby('username').cumcount()
data_after = (
    data_after[posts_seen >= REACHBACK]
    .sort_values('username', kind='stable')
    .reset_index(drop=True)
)

In [None]:
# two branch
# Build (train, validation) loaders per group from the pre-split data;
# groups with 100 rows or fewer are reported and left out.

trainval_grouped_jan2025 = {}
for group, usernames in groups.items():
  group_df = data_before.query('username in @usernames')
  n_group_rows = len(group_df)
  print(group, n_group_rows, '-'*30)
  if n_group_rows <= 100:
    print(f'{group} not enough data')
    continue
  trainval_grouped_jan2025[group] = df_to_traintest(group_df, TwoBranchPrepData)

In [None]:
# Train one two-branch model per group on the Jan-2025 split's loaders.
# Maps group -> (trained model, train loss history, validation loss history).
model_grouped_jan2025= {}
for group, (train_loader, val_loader) in trainval_grouped_jan2025.items():
  print(f"Training {group}")
  # NOTE(review): the keyword is spelled `verobose`; presumably it matches the
  # constructor's parameter name — confirm against TwoBranchLikesPredictor.
  model = TwoBranchLikesPredictor(
      sbert_dim=384,
      reachback_length=REACHBACK,
      verobose=False
  )

  # Use the GPU when available; `device` is also read by the residual cell below.
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  # device = torch.device('cpu')
  model = model.to(device)
  print(f"Using device: {device}")

  model, train_loss_history, val_loss_history = train_model(model, train_loader, val_loader, lr=0.001, verbose=False, num_epochs=20)

  model_grouped_jan2025[group] = (model, train_loss_history, val_loss_history)

In [None]:
# Inspection: the groups for which a Jan-2025 model was trained.
list(model_grouped_jan2025.keys())

In [None]:
# NOTE(review): `residuals_grouped_jan2025` is assigned in the following cell
# and no earlier assignment is visible here, so this inspection presumably
# raises NameError on a fresh top-to-bottom run — confirm and move or remove.
residuals_grouped_jan2025.keys()

In [None]:
# Residuals (target - prediction) of every Jan-2025 group model.
# Maps group -> (residuals_before, residuals_after); each is a list of 0-d
# tensors. Groups with fewer than 30 rows on either side are skipped.
residuals_grouped_jan2025 = {}

for group, (model, train_loss_history, val_loss_history) in tqdm(model_grouped_jan2025.items()):
  usernames = groups[group]
  data_after_group = data_after.query('username in @usernames')
  data_before_group = data_before.query('username in @usernames')
  if len(data_after_group) < 30:
    print(f'{group} not enough after data')
    continue
  if len(data_before_group) < 30:
    print(f'{group} not enough before data')
    continue
  dataloader_after = data_loader(data_after_group, TwoBranchPrepData)
  dataloader_before = data_loader(data_before_group, TwoBranchPrepData)

  model.eval()
  residuals_before = []
  # The train / validation residuals are not computed for this split; the
  # empty lists are kept so the names stay bound as before.
  residuals_before_train = []
  residuals_before_val = []
  residuals_after = []
  with torch.no_grad():
    # Same evaluation order as before: "after" first, then "before".
    for _loader, _sink in ((dataloader_after, residuals_after), (dataloader_before, residuals_before)):
      for inputs, targets in _loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Move to the CPU so the residuals can be plotted, tested and pickled
        # regardless of the device the model ran on.
        _sink.extend((targets - outputs.squeeze()).cpu())

  residuals_grouped_jan2025[group] = (residuals_before, residuals_after)

## plot residuals

In [None]:
# One panel per group (in `group_order`): Jan-2025 split residuals before vs.
# after. Groups without residuals get a "not enough data" placeholder.
fig, axes = plt.subplots(3,5, figsize=(10,5), sharey = True, sharex = True)
plt.subplots_adjust(wspace=0, hspace=0)

# Labels follow this section's split dates (before 1 Oct 2024 / from
# 6 Jan 2025) rather than the Feb-2024 split used earlier in the notebook.
before_label = 'Before Oct 2024'
after_label = 'From 6 Jan 2025'
residual_bins = np.arange(-10.5,10.5,1)

for i, (ax,(group)) in enumerate(zip(axes.flat,group_order)):
  ax.set_xlim([-12, 12])
  if group in residuals_grouped_jan2025.keys():
    residuals_before, residuals_after = residuals_grouped_jan2025[group]
    ax.hist(residuals_before, alpha = 0.5 ,density=True, label = before_label, bins = residual_bins)
    ax.hist(residuals_after, alpha = 0.5 ,density=True, label = after_label, bins = residual_bins)
    # dashed reference line at zero residual
    ax.vlines(0, 0, 0.41, linestyles='dashed', colors='black', alpha = 0.25)
  else:
    ax.text(0.5, 0.5, "not enough data", ha='center', va='center', transform=ax.transAxes, color = 'darkgrey')
  ax.text(0.5, 0.75, group_titles[group], ha='center', va='center', transform=ax.transAxes)

  if i == len(group_order) - 1:
    # Explicit handles: the last panel may be a placeholder with no
    # histograms, which would otherwise produce an empty legend.
    legend_handles = [
        mpatches.Patch(color='tab:blue', alpha=0.5, label=before_label),
        mpatches.Patch(color='tab:orange', alpha=0.5, label=after_label),
    ]
    ax.legend(handles=legend_handles, loc = 'lower right', bbox_to_anchor=(2.1,0.3))


# The grid has 15 slots; the last one is removed to leave room for the legend.
fig.delaxes(axes[2][4])
fig.supxlabel('Residuals (log likes)')
fig.supylabel('Frequency', x=0.055)
# fig.tight_layout()

# plt.savefig(predir+'/figs/residuals_large_jan2025.png')

In [None]:
# t-test on averages before and after
# One two-sample t-test per group comparing the model residuals before and
# after the split, with the confidence interval of the mean difference.
from scipy.stats import ttest_ind
from scipy.stats import permutation_test
from datetime import timedelta

ttest_rows = []

for group, (residuals_before, residuals_after) in tqdm(residuals_grouped_jan2025.items()):

  # float() accepts 0-d tensors on any device as well as plain numbers, so
  # scipy always receives ordinary float arrays.
  before_values = np.array([float(r) for r in residuals_before])
  after_values = np.array([float(r) for r in residuals_after])

  # ttest = ttest_ind(before, after, permutations=10_000)
  ttest = ttest_ind(before_values, after_values)

  interval = ttest.confidence_interval()
  ttest_rows.append({'group': group, 'tstat': ttest.statistic, 'pval': ttest.pvalue, 'low': interval.low, 'high': interval.high})

# Build the frame once instead of concatenating onto an empty frame per group.
df_ttest = pd.DataFrame(ttest_rows, columns=['group','tstat','pval','low','high'])

# bonferroni correction
# NOTE(review): 14 is presumably the total number of groups — confirm, since
# groups skipped for lack of data are not tested here.
df_ttest['signif'] = (df_ttest['pval'] * 14) < 0.05

In [None]:
# Horizontal bar chart of the per-group t statistics held in `df_ttest`.
fig, ax1 = plt.subplots(1,1, figsize=(10,5), sharey = True)

plt.subplots_adjust(wspace=0, hspace=0)

# Rows are reversed so the first group ends up at the top of the chart.
ttest_plotted = df_ttest[::-1]
ttest_signif = ttest_plotted[ttest_plotted['signif'].astype(bool)]
n_bars = len(ttest_plotted)

ax1.barh(ttest_plotted['group'], ttest_plotted['tstat'], alpha = 0.5)
# go again for the significant ones
ax1.barh(ttest_signif['group'], ttest_signif['tstat'], color = 'tab:blue', zorder = 2)

# use fill between to color positive values green and negative values red
ax1.fill_betweenx([0,n_bars-1], -30, 0, color='red', alpha=0.25, zorder = 1)
ax1.fill_betweenx([0,n_bars-1], 0, 30, color='green', alpha=0.25, zorder = 1)

ax1.grid(zorder=1)
# Label each bar from the rows actually plotted, so the titles stay aligned
# with the bars even when some groups were skipped and are absent from df_ttest.
ax1.set_yticks(range(n_bars), [group_titles.get(g, g) for g in ttest_plotted['group']])
ax1.set_ylabel('Group')
ax1.set_xlabel('T-Statistic')

# plt.savefig(predir+'/figs/ttest_large_jan2025.png')

# split no meaning


In [None]:
# "Before" window: after 1 Feb 2024 and strictly before the split date.
split_before = pd.to_datetime('2024-07-01') # also try 6, 5, 4
data_before = data[data['date'] < split_before]
data_before = data_before[data_before['date'] > pd.to_datetime('2024-02-01')]

# "After" window: 1 Aug 2024 up to and including 1 Nov 2024.
split = pd.to_datetime('2024-08-01')
data_after = data[data['date'] >= split]
data_after = data_after[data_after['date'] <= pd.to_datetime('2024-11-01')]

# For the data after, drop the first REACHBACK posts of each account.
# cumcount() numbers each account's rows 0, 1, 2, ... in their current order.
# The boolean mask replaces groupby.apply, which relied on the grouping column
# being passed to the function (deprecated in pandas 2.2, so `username` would
# be lost). The stable sort reproduces apply's row order: accounts sorted,
# original order kept within each account.
posts_seen = data_after.groupby('username').cumcount()
data_after = (
    data_after[posts_seen >= REACHBACK]
    .sort_values('username', kind='stable')
    .reset_index(drop=True)
)

In [None]:
# two branch
# Build (train, validation) loaders per group from the pre-split data;
# groups with 100 rows or fewer are reported and left out.

trainval_grouped_nomeaning = {}
for group, usernames in groups.items():
  group_df = data_before.query('username in @usernames')
  n_group_rows = len(group_df)
  print(group, n_group_rows, '-'*30)
  if n_group_rows <= 100:
    print(f'{group} not enough data')
    continue
  trainval_grouped_nomeaning[group] = df_to_traintest(group_df, TwoBranchPrepData)

In [None]:
# Train one two-branch model per group on the "no meaning" split's loaders.
# Maps group -> (trained model, train loss history, validation loss history).
model_grouped_nomeaning= {}
for group, (train_loader, val_loader) in trainval_grouped_nomeaning.items():
  print(f"Training {group}")
  # NOTE(review): the keyword is spelled `verobose`; presumably it matches the
  # constructor's parameter name — confirm against TwoBranchLikesPredictor.
  model = TwoBranchLikesPredictor(
      sbert_dim=384,
      reachback_length=REACHBACK,
      verobose=False
  )

  # Use the GPU when available; `device` is also read by the residual cell below.
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  # device = torch.device('cpu')
  model = model.to(device)
  print(f"Using device: {device}")

  model, train_loss_history, val_loss_history = train_model(model, train_loader, val_loader, lr=0.001, verbose=False, num_epochs=20)

  model_grouped_nomeaning[group] = (model, train_loss_history, val_loss_history)

In [None]:
# Training / validation loss curves, one small panel per group on a 3x5 grid,
# annotated with the group name and its final validation loss.
for i, (group, (model, train_loss_history, val_loss_history)) in enumerate(model_grouped_nomeaning.items()):
  plt.subplot(3,5,i+1)
  plt.text(0.5, 0.5, f'{group}:\n{val_loss_history[-1]:.2f}', ha='center', va='center', transform=plt.gca().transAxes)
  # line plus small markers for each curve
  plt.plot(train_loss_history, label='Train Loss', )
  plt.scatter(range(len(train_loss_history)), train_loss_history, color='tab:blue', s=5)
  plt.plot(val_loss_history, label='Validation Loss')
  plt.scatter(range(len(val_loss_history)), val_loss_history, color='tab:orange', s= 5)
  # only the first column keeps its y axis
  if i % 5 == 0:
    plt.ylabel('Loss')
  else:
    plt.yticks([])
  plt.xlabel('Epoch')

plt.legend(loc = 'lower right', bbox_to_anchor=(2.1,0.3))
plt.show()

In [None]:
# Residuals (target - prediction) of every "no meaning" group model.
# Maps group -> (residuals_before_train, residuals_before_val,
# residuals_before, residuals_after). The train / validation lists are left
# empty here; the other two are lists of 0-d tensors. Groups with fewer than
# 30 rows on either side are skipped.
residuals_grouped_nomeaning = {}

for group, (model, train_loss_history, val_loss_history) in tqdm(model_grouped_nomeaning.items()):
  usernames = groups[group]
  data_after_group = data_after.query('username in @usernames')
  data_before_group = data_before.query('username in @usernames')
  if len(data_after_group) < 30:
    print(f'{group} not enough after data')
    continue
  if len(data_before_group) < 30:
    print(f'{group} not enough before data')
    continue
  dataloader_after = data_loader(data_after_group, TwoBranchPrepData)
  dataloader_before = data_loader(data_before_group, TwoBranchPrepData)
  # Fetched but not evaluated in this cell.
  dataloader_before_train = trainval_grouped_nomeaning[group][0]
  dataloader_before_val = trainval_grouped_nomeaning[group][1]

  model.eval()
  residuals_before = []
  residuals_before_train = []
  residuals_before_val = []
  residuals_after = []
  with torch.no_grad():
    # Same evaluation order as before: "after" first, then "before".
    for _loader, _sink in ((dataloader_after, residuals_after), (dataloader_before, residuals_before)):
      for inputs, targets in _loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        # Move to the CPU so the residuals can be plotted, tested and pickled
        # regardless of the device the model ran on.
        _sink.extend((targets - outputs.squeeze()).cpu())

  residuals_grouped_nomeaning[group] = (residuals_before_train,residuals_before_val, residuals_before, residuals_after)

In [None]:
# Uncomment to cache the residuals computed in the previous cell:
# with open(predir+'/residuals_grouped_nomeaning.pkl', 'wb') as f:
#   pickle.dump(residuals_grouped_nomeaning, f)

# Load the cached residuals, replacing whatever the previous cell computed.
# NOTE(review): pickle.load can execute arbitrary code from the file — only
# load files you created yourself.
with open(predir + '/residuals_grouped_nomeaning.pkl', 'rb') as f:
  residuals_grouped_nomeaning = pickle.load(f)

In [None]:
# Debug check: size of whichever `residuals_before_val` list was bound last
# (it is a loop variable left over from an earlier cell).
len(residuals_before_val)

In [None]:
# Pool the residuals of every group. float() accepts 0-d tensors on any
# device as well as plain numbers, so matplotlib always gets ordinary floats.
all_before = []
all_before_val = []
all_after = []
for group, (residuals_before_train,residuals_before_val, residuals_before, residuals_after) in residuals_grouped_nomeaning.items():
  all_before.extend(float(r) for r in residuals_before)
  all_before_val.extend(float(r) for r in residuals_before_val)
  all_after.extend(float(r) for r in residuals_after)

residual_bins = np.arange(-10.5,10.5,1)

plt.figure(figsize=(10,5))
# Left: all "before" residuals vs. "after".
plt.subplot(1,2,1)
plt.hist(all_before, alpha = 0.5 ,density=True, label = 'Before\n(All)', bins = residual_bins)
plt.hist(all_after, alpha = 0.5 ,density=True, label = 'After', bins = residual_bins)
plt.xlabel('Residuals (log likes)')
plt.ylabel('Density')
plt.legend()

# Right: validation-only "before" residuals vs. "after". The validation
# residuals are pooled over all groups, like the "after" residuals they are
# compared with, instead of using only the last group's list.
plt.subplot(1,2,2)
plt.hist(all_before_val, alpha = 0.5 ,density=True, label = 'Before\n(Validation)', bins = residual_bins, color = 'tab:green')
plt.hist(all_after, alpha = 0.5 ,density=True, label = 'After', bins = residual_bins, color = 'tab:orange')
plt.xlabel('Residuals (log likes)')
plt.legend()

# plt.savefig(predir+'/figs/residuals_val_vs_all.png', dpi = 800)

## plot residuals

In [None]:
# Reload the cached residuals so this section can run without re-training.
# NOTE(review): pickle.load can execute arbitrary code from the file — only
# load files you created yourself.
with open(predir+'/residuals_grouped_nomeaning.pkl', 'rb') as f:
  residuals_grouped_nomeaning = pickle.load(f)

In [None]:
# t-test on averages before and after
# Two tests per group: one on the model residuals (residual=True) and one on
# the raw log likes (residual=False).
from scipy.stats import ttest_ind
from scipy.stats import permutation_test
from datetime import timedelta

ttest_rows = []

for group, (residuals_before_train,residuals_before_val, residuals_before, residuals_after) in tqdm(residuals_grouped_nomeaning.items()):

    # The stored residuals are tensors; convert them for scipy.
    residuals_before = [x.cpu().numpy() for x in residuals_before]
    residuals_after = [x.cpu().numpy() for x in residuals_after]

    ttest = ttest_ind(residuals_before, residuals_after)
    # interval = ttest.confidence_interval()
    ttest_rows.append({'group': group, 'tstat': ttest[0], 'pval': ttest[1], 'residual': True})

    # NOTE(review): assumes data_before / data_after carry a `group` column
    # and still hold this section's split — confirm.
    before_ll = data_before.query('group == @group')['log_likes']
    after_ll = data_after.query('group == @group')['log_likes']

    ttest = ttest_ind(before_ll, after_ll)

    ttest_rows.append({'group': group, 'tstat': ttest[0], 'pval': ttest[1], 'residual': False})

# Placeholder rows (no statistics) for the two groups that are hard-coded as
# missing, so they still appear downstream.
for placeholder_group in ['healthleft', 'queer']:
    for is_residual in (True, False):
        ttest_rows.append({'group': placeholder_group, 'residual': is_residual})

# Build the frame once instead of concatenating onto an empty frame per test.
df_ttest = pd.DataFrame(ttest_rows, columns=['group','tstat','pval','residual'])

# bonferroni correction; placeholder rows have a NaN p-value, which compares
# as not significant.
df_ttest['signif'] = (df_ttest['pval'] * 14) < 0.05

# # order df_ttest by group_order
# df_ttest = df_ttest.set_index('group')
# # fill in missing groups
# df_ttest = df_ttest.reindex(group_order)
# df_ttest = df_ttest.loc[group_order].reset_index()

# Assign the result back instead of calling fillna(inplace=True) on a column
# selection, which does not update the frame under copy-on-write.
df_ttest['tstat'] = df_ttest['tstat'].fillna(0)
df_ttest['signif_values'] = np.where(df_ttest['signif'], df_ttest['tstat'], 0)

In [None]:
fig = plt.figure(figsize=(10,4))
subfigs = fig.subfigures(1,2, width_ratios = [1,3], wspace= -0.1)

# Second sub-figure: a 3x5 grid of per-group residual histograms on shared axes.
axL = subfigs[1].subplots(3,5, sharey = True, sharex = True)
plt.subplots_adjust(wspace=0, hspace=0)

hist_bins = np.arange(-10.5,10.5,1)
legend_panel = len(group_order) - 2  # panel whose handles feed the legend

for i, (ax, group) in enumerate(zip(axL.flat, group_order)):
  ax.set_xlim([-12, 12])

  if group not in residuals_grouped_nomeaning:
    # No residuals stored for this group: leave a placeholder note.
    ax.text(0.5, 0.5, "not enough data", ha='center', va='center', transform=ax.transAxes, color = 'darkgrey')
  else:
    residuals_before_train, residuals_before_val, residuals_before, residuals_after = residuals_grouped_nomeaning[group]
    # Tensors -> numpy so matplotlib can bin them.
    residuals_before = [x.cpu().numpy() for x in residuals_before]
    residuals_after = [x.cpu().numpy() for x in residuals_after]
    for sample, sample_label in ((residuals_before, 'Before'), (residuals_after, 'After')):
      ax.hist(sample, alpha = 0.5, density=True, label = sample_label, bins = hist_bins)
    # Dashed reference line at zero residual.
    ax.vlines(0, 0, 0.41, colors='black', alpha = 0.25, linestyles="dashed")

  # Group title inside the panel, then a common y-range for every panel.
  ax.text(0.5, 0.75, group_titles[group], ha='center', va='center', transform=ax.transAxes)
  ax.set_ylim([0,0.47])

  if i == legend_panel:
    ax.legend(loc = 'lower right', bbox_to_anchor=(2.9,0.3))
  if i % 5:
    # Only the first column keeps its y axis.
    ax.yaxis.set_visible(False)
  if i == 9:
    ax.xaxis.set_visible(False)

# The grid has one more panel than is used; drop the last one.
axL.flat[-1].remove()
subfigs[1].supxlabel('Residuals (log likes)', y=-0.02)
subfigs[1].supylabel('Density', x=0.055)

#######################################
axR = subfigs[0].subplots(1,1)

# Draw bottom-to-top in reversed group_order so the first group sits at the top.
# Every series is aligned to this order and plotted at explicit numeric
# positions, so bars, lines and tick labels always refer to the same group
# regardless of the row order of df_ttest.
plot_groups = list(group_order)[::-1]
y_pos = np.arange(len(plot_groups))


def _aligned_to_plot_groups(rows, column):
  """Return `column` of `rows` as a float array ordered like plot_groups (0 where a group is absent or NaN)."""
  per_group = rows.drop_duplicates('group').set_index('group')[column]
  return pd.to_numeric(per_group.reindex(plot_groups), errors='coerce').fillna(0).to_numpy(dtype=float)


residual_rows = df_ttest.query('residual == True')   # t-tests on model residuals
raw_rows = df_ttest.query('residual == False')       # t-tests on raw log likes

# Residual t-statistics: faded bar for every group, solid overlay where significant
# (signif_values is 0 for non-significant rows, giving an invisible zero-width bar).
axR.barh(y_pos, _aligned_to_plot_groups(residual_rows, 'tstat'), alpha = 0.5)
axR.barh(y_pos, _aligned_to_plot_groups(residual_rows, 'signif_values'), color = 'tab:blue', zorder = 2)

# Raw log-likes t-statistics: dashed line for every group, solid overlay where significant.
axR.hlines(y_pos, xmin=0, xmax=_aligned_to_plot_groups(raw_rows, 'tstat'), color='black', zorder = 4, alpha = 0.5, linestyles='dashed')
axR.hlines(y_pos, xmin=0, xmax=_aligned_to_plot_groups(raw_rows, 'signif_values'), color='black', zorder = 4)


# use fill between to color postive values green and negative values red
axR.fill_betweenx([-100,100], -100, 0, color='red', alpha=0.15, zorder = 1)
axR.fill_betweenx([-100,100], 0, 100, color='green', alpha=0.15, zorder = 1)

# set ylim xlim
axR.set_ylim([-1,len(plot_groups)])
axR.set_xlim([-25,25])

axR.grid(zorder=1, alpha = 0.5,  linestyle='--')
axR.set_yticks(y_pos, [group_titles.get(g, g) for g in plot_groups])
axR.set_xticks([-20,-10,0,10,20])

# signif_patch = mpatches.Patch(color='tab:blue', alpha = 1, label = 'Significant')
# insignif_patch = mpatches.Patch(color='tab:blue', alpha = 0.5, label = 'Insignificant')
# axR.legend(handles=[signif_patch, insignif_patch], loc = 'lower left')

subfigs[0].supxlabel('T-Statistic', y = 0)
subfigs[0].supylabel('Group', x = -0.32)

# plt.savefig(predir+'/figs/ttest_control.png',bbox_inches='tight', dpi = 800)