In [1]:
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
from tqdm import tqdm

# Load the parquet file
df = pd.read_parquet("data/bluesky_text_embeddings (1).parquet")

# Unpack the binary embeddings
def unpack_embeddings(packed_bytes):
    return np.unpackbits(np.frombuffer(packed_bytes, dtype=np.uint8))

# Apply unpacking to get original binary embeddings
df['embeddings'] = df['embeddings'].apply(unpack_embeddings)

# Now you can look at the first few rows to verify
print("First embedding shape:", len(df['embeddings'].iloc[0]))
print(df[['item_id', 'embeddings']].head())

First embedding shape: 128
   item_id                                         embeddings
0  3460233  [1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, ...
1  3044498  [1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, ...
2  1582998  [1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, ...
3  5436174  [1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, ...
4  1582999  [1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, ...


In [2]:
import pandas as pd
import numpy as np
import duckdb

# 1. Load and unpack embeddings
con = duckdb.connect()
con.execute("""
    CREATE TABLE embeddings AS SELECT * FROM read_parquet('data/bluesky_text_embeddings (1).parquet');
""")
post_embeddings_df = con.execute("SELECT * FROM embeddings").fetchdf()

def unpack_embeddings(packed_bytes):
    return np.unpackbits(np.frombuffer(packed_bytes, dtype=np.uint8))

post_embeddings_df['embeddings'] = post_embeddings_df['embeddings'].apply(unpack_embeddings)

# 2. Load interactions
con.execute("""
    CREATE TABLE interactions AS SELECT * FROM read_csv('data/bluesky.csv');
""")
interactions_df = con.execute("SELECT * FROM interactions").fetchdf()

# 3. Join interactions with post embeddings
joined_df = interactions_df.merge(
    post_embeddings_df,
    left_on='destination_node',
    right_on='item_id',
    how='inner'
)

# 4. Group by user and create user embeddings
user_embeddings = joined_df.groupby('source_node')['embeddings'].agg(
    lambda x: np.mean(list(x), axis=0)
).reset_index()
user_embeddings.columns = ['user_id', 'user_embedding']

# 5. Create final DataFrame with all information
final_df = joined_df.merge(
    user_embeddings,
    left_on='source_node',
    right_on='user_id',
    how='inner'
)

# Verify the data
print("Final DataFrame shape:", final_df.shape)
print("\nColumns:", final_df.columns.tolist())
print("\nSample user-post pair:")
sample = final_df.iloc[0]
print(f"User ID: {sample['source_node']}")
print(f"Post ID: {sample['destination_node']}")
print(f"User embedding (first 10):", sample['user_embedding'][:10])
print(f"Post embedding (first 10):", sample['embeddings'][:10])

Final DataFrame shape: (16943200, 8)

Columns: ['source_node', 'destination_node', 'timestamp', 'edge_label', 'item_id', 'embeddings', 'user_id', 'user_embedding']

Sample user-post pair:
User ID: 50947
Post ID: 3460233
User embedding (first 10): [0.96829477 0.         0.86032562 0.48671808 0.95886889 0.04798629
 0.15595544 0.9991431  0.46615253 0.48500428]
Post embedding (first 10): [1 0 1 1 1 0 0 1 1 0]


In [3]:
# Pick a random user_id to examine
user_id = final_df['user_id'].iloc[21]

# Get their embedding
user_emb = final_df[final_df['user_id'] == user_id]['user_embedding'].iloc[0]

# Get all posts this user liked
liked_posts = final_df[final_df['user_id'] == user_id]['destination_node'].values

# Get embeddings for these posts
liked_post_embeddings = post_embeddings_df[post_embeddings_df['item_id'].isin(liked_posts)]['embeddings'].values

print(f"User {user_id}:")
print(f"Number of liked posts: {len(liked_posts)}")
print(f"\nUser embedding (first 20 values):\n{user_emb[:20]}")
print(f"\nLiked posts embeddings (first 3 posts, first 20 values):")
for i, emb in enumerate(liked_post_embeddings[:9]):
    print(f"Post {i}: {emb[:20]}")

# Verify that user embedding is indeed the average
avg_liked_embeddings = np.mean(liked_post_embeddings, axis=0)
print(f"\nVerification - are user embeddings the average of liked posts?")
print(f"Max difference: {np.max(np.abs(user_emb - avg_liked_embeddings))}")

User 35444:
Number of liked posts: 299

User embedding (first 20 values):
[0.95317726 0.         0.85284281 0.56856187 0.9632107  0.05685619
 0.13043478 0.98996656 0.4548495  0.55518395 0.34448161 1.
 0.33110368 0.11371237 0.05016722 0.33779264 0.03344482 0.94983278
 0.55518395 0.5819398 ]

Liked posts embeddings (first 3 posts, first 20 values):
Post 0: [1 0 0 1 1 0 0 1 1 0 0 1 0 0 0 0 0 1 1 0]
Post 1: [1 0 0 1 1 0 0 1 0 0 1 1 1 0 0 0 0 1 1 1]
Post 2: [1 0 0 1 1 0 0 1 1 1 0 1 0 0 0 1 0 1 1 1]
Post 3: [1 0 1 1 1 0 0 1 0 0 0 1 0 1 0 0 0 1 0 1]
Post 4: [1 0 1 1 1 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0]
Post 5: [1 0 1 0 0 0 0 1 1 1 0 1 1 1 0 0 0 1 1 0]
Post 6: [1 0 1 0 1 0 0 1 0 0 1 1 0 0 0 0 0 1 1 0]
Post 7: [1 0 1 1 1 0 0 1 0 0 1 1 0 0 0 1 0 1 0 1]
Post 8: [1 0 1 1 1 0 0 1 0 1 0 1 1 0 0 0 0 1 1 0]

Verification - are user embeddings the average of liked posts?
Max difference: 1.0


In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from model import UserTower, PostTower, TwoTowerModel
from sklearn.model_selection import train_test_split
import torch.nn as nn

class UserPostDataset(Dataset):
    def __init__(self, df, negative_samples=1):
        self.df = df
        self.negative_samples = negative_samples
        # Precompute user's positive posts for faster lookup
        self.user_positives = {
            user: set(group['destination_node'].values) 
            for user, group in df.groupby('source_node')
        }
        self.all_posts = df['destination_node'].unique()
        # Calculate total length including negative samples
        self.length = len(df) * (self.negative_samples + 1)
        
    def __len__(self):
        return self.length  # Return integer length
    
    def __getitem__(self, idx):
        interaction_idx = idx // (self.negative_samples + 1)
        is_positive = idx % (self.negative_samples + 1) == 0
        
        row = self.df.iloc[interaction_idx]
        user_id = row['source_node']
        
        if is_positive:
            post_emb = row['embeddings']
            user_emb = row['user_embedding']
        else:
            # Simple random sampling without checking
            neg_post_idx = np.random.choice(len(self.df))
            neg_post = self.df.iloc[neg_post_idx]
            post_emb = neg_post['embeddings']
            user_emb = row['user_embedding']
        
        return (
            torch.tensor(user_emb, dtype=torch.float32),
            torch.tensor(post_emb, dtype=torch.float32),
            torch.tensor(1.0 if is_positive else 0.0, dtype=torch.float32)
        )

# Create datasets and dataloaders
train_df, val_df = train_test_split(final_df, test_size=0.2, random_state=42)

train_dataset = UserPostDataset(train_df)
val_dataset = UserPostDataset(val_df)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False, num_workers=4)

In [5]:
# # 2. Split data and create dataloaders
# train_df, val_df = train_test_split(final_df, test_size=0.2, random_state=42)
# all_post_ids = final_df['destination_node'].unique()

# train_dataset = UserPostDataset(train_df, all_post_ids)
# val_dataset = UserPostDataset(val_df, all_post_ids)

# train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4)

# # 3. Initialize model
# embedding_dim = len(final_df['embeddings'].iloc[0])  # Should be 128
# hidden_dims = [64, 32]

# user_tower = UserTower(embedding_dim=embedding_dim, hidden_dims=hidden_dims)
# post_tower = PostTower(embedding_dim=embedding_dim, hidden_dims=hidden_dims)
# model = TwoTowerModel(user_tower, post_tower)

# # 4. Training setup
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(f"Using device: {device}")

# model = model.to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# criterion = nn.BCEWithLogitsLoss()

# 5. Training loop
# num_epochs = 3
# best_val_loss = float('inf')

# for epoch in range(num_epochs):
#     # Training phase
#     model.train()
#     train_loss = 0
#     for batch_idx, (user_features, post_features, labels) in enumerate(train_loader):
#         user_features = user_features.to(device)
#         post_features = post_features.to(device)
#         labels = labels.to(device)
        
#         optimizer.zero_grad()
#         user_emb, post_emb = model(user_features, post_features)
        
#         # Compute similarity scores
#         scores = torch.sum(user_emb * post_emb, dim=1)
#         loss = criterion(scores, labels)
        
#         loss.backward()
#         optimizer.step()
        
#         train_loss += loss.item()
        
#         if batch_idx % 100 == 0:
#             print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
    
#     avg_train_loss = train_loss / len(train_loader)
    
#     # Validation phase
#     model.eval()
#     val_loss = 0
#     correct_predictions = 0
#     total_predictions = 0
    
#     with torch.no_grad():
#         for user_features, post_features, labels in val_loader:
#             user_features = user_features.to(device)
#             post_features = post_features.to(device)
#             labels = labels.to(device)
            
#             user_emb, post_emb = model(user_features, post_features)
#             scores = torch.sum(user_emb * post_emb, dim=1)
#             loss = criterion(scores, labels)
            
#             val_loss += loss.item()
            
#             # Calculate accuracy
#             predictions = (torch.sigmoid(scores) > 0.5).float()
#             correct_predictions += (predictions == labels).sum().item()
#             total_predictions += labels.size(0)
    
#     avg_val_loss = val_loss / len(val_loader)
#     accuracy = correct_predictions / total_predictions
    
#     print(f'Epoch {epoch}:')
#     print(f'  Training Loss: {avg_train_loss:.4f}')
#     print(f'  Validation Loss: {avg_val_loss:.4f}')
#     print(f'  Validation Accuracy: {accuracy:.4f}')
    
#     # Save best model
#     if avg_val_loss < best_val_loss:
#         best_val_loss = avg_val_loss
#         torch.save(model.state_dict(), 'best_model.pt')

In [6]:
embedding_dim=128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Modified model with simpler architecture for testing
class SimpleTwoTowerModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.user_tower = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        self.post_tower = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        
    def forward(self, user_features, post_features):
        user_emb = F.normalize(self.user_tower(user_features), p=2, dim=1)
        post_emb = F.normalize(self.post_tower(post_features), p=2, dim=1)
        return user_emb, post_emb

# Try training with this simpler model
model = SimpleTwoTowerModel(embedding_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

# train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False, num_workers=4)

# Training loop with gradient norm monitoring
for batch_idx, (user_features, post_features, labels) in enumerate(tqdm(train_loader)):
    user_features = user_features.to(device)
    post_features = post_features.to(device)
    labels = labels.to(device)
    
    optimizer.zero_grad()
    user_emb, post_emb = model(user_features, post_features)
    scores = torch.sum(user_emb * post_emb, dim=1)
    loss = criterion(scores, labels)
    
    loss.backward()
    
    # Monitor gradients
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    
    optimizer.step()
    
    if batch_idx % 1000 == 0:
        print(f"Batch {batch_idx} statistics:")
        print(f"  Loss: {loss.item():.4f}")
        print(f"  Gradient norm: {total_norm:.4f}")
        print(f"  Score range: [{scores.min().item():.4f}, {scores.max().item():.4f}]")
        print(f"  Prediction mean: {torch.sigmoid(scores).mean().item():.4f}")

  0%|          | 1/26474 [00:02<17:32:25,  2.39s/it]

Batch 0 statistics:
  Loss: 0.7016
  Gradient norm: 0.3347
  Score range: [-0.0336, 0.4606]
  Prediction mean: 0.5540


  0%|          | 105/26474 [00:06<17:52, 24.58it/s] 

Batch 100 statistics:
  Loss: 0.6425
  Gradient norm: 0.2732
  Score range: [-0.9322, 0.9787]
  Prediction mean: 0.5250


  1%|          | 203/26474 [00:10<18:57, 23.09it/s]

Batch 200 statistics:
  Loss: 0.6208
  Gradient norm: 0.1968
  Score range: [-0.9408, 0.9831]
  Prediction mean: 0.5263


  1%|          | 302/26474 [00:14<16:11, 26.94it/s]

Batch 300 statistics:
  Loss: 0.6329
  Gradient norm: 0.1626
  Score range: [-0.9465, 0.9658]
  Prediction mean: 0.5216


  2%|▏         | 405/26474 [00:20<19:47, 21.96it/s]

Batch 400 statistics:
  Loss: 0.6214
  Gradient norm: 0.1214
  Score range: [-0.9389, 0.9417]
  Prediction mean: 0.5351


  2%|▏         | 504/26474 [00:24<15:09, 28.55it/s]

Batch 500 statistics:
  Loss: 0.6231
  Gradient norm: 0.1655
  Score range: [-0.9722, 0.9068]
  Prediction mean: 0.5223


  2%|▏         | 604/26474 [00:28<16:37, 25.93it/s]

Batch 600 statistics:
  Loss: 0.6253
  Gradient norm: 0.2467
  Score range: [-0.9565, 0.9417]
  Prediction mean: 0.5328


  3%|▎         | 704/26474 [00:31<15:58, 26.89it/s]

Batch 700 statistics:
  Loss: 0.6255
  Gradient norm: 0.1783
  Score range: [-0.9539, 0.9327]
  Prediction mean: 0.5311


  3%|▎         | 800/26474 [00:35<15:27, 27.67it/s]

Batch 800 statistics:
  Loss: 0.6208
  Gradient norm: 0.2681
  Score range: [-0.9482, 0.9709]
  Prediction mean: 0.5380


  3%|▎         | 904/26474 [00:40<15:59, 26.64it/s]

Batch 900 statistics:
  Loss: 0.6235
  Gradient norm: 0.1917
  Score range: [-0.9608, 0.9313]
  Prediction mean: 0.5224


  4%|▍         | 1004/26474 [00:44<15:11, 27.93it/s]

Batch 1000 statistics:
  Loss: 0.6174
  Gradient norm: 0.1815
  Score range: [-0.9564, 0.9387]
  Prediction mean: 0.5342


  4%|▍         | 1104/26474 [00:48<15:27, 27.34it/s]

Batch 1100 statistics:
  Loss: 0.6143
  Gradient norm: 0.1605
  Score range: [-0.9444, 0.9658]
  Prediction mean: 0.5346


  5%|▍         | 1204/26474 [00:51<15:25, 27.30it/s]

Batch 1200 statistics:
  Loss: 0.6212
  Gradient norm: 0.1902
  Score range: [-0.9542, 0.9573]
  Prediction mean: 0.5362


  5%|▍         | 1303/26474 [00:56<17:51, 23.49it/s]

Batch 1300 statistics:
  Loss: 0.6218
  Gradient norm: 0.1377
  Score range: [-0.9504, 0.9365]
  Prediction mean: 0.5294


  5%|▌         | 1405/26474 [01:00<15:30, 26.95it/s]

Batch 1400 statistics:
  Loss: 0.6203
  Gradient norm: 0.1513
  Score range: [-0.9668, 0.8909]
  Prediction mean: 0.5312


  6%|▌         | 1505/26474 [01:04<14:45, 28.20it/s]

Batch 1500 statistics:
  Loss: 0.6221
  Gradient norm: 0.1652
  Score range: [-0.9579, 0.9359]
  Prediction mean: 0.5303


  6%|▌         | 1605/26474 [01:07<15:47, 26.25it/s]

Batch 1600 statistics:
  Loss: 0.6274
  Gradient norm: 0.1423
  Score range: [-0.9553, 0.9410]
  Prediction mean: 0.5177


  6%|▋         | 1701/26474 [01:11<16:07, 25.61it/s]

Batch 1700 statistics:
  Loss: 0.6230
  Gradient norm: 0.1592
  Score range: [-0.9750, 0.9436]
  Prediction mean: 0.5369


  7%|▋         | 1805/26474 [01:16<15:23, 26.70it/s]

Batch 1800 statistics:
  Loss: 0.6118
  Gradient norm: 0.1620
  Score range: [-0.9704, 0.9359]
  Prediction mean: 0.5293


  7%|▋         | 1903/26474 [01:20<16:02, 25.52it/s]

Batch 1900 statistics:
  Loss: 0.6177
  Gradient norm: 0.1992
  Score range: [-0.9732, 0.9307]
  Prediction mean: 0.5391


  8%|▊         | 2007/26474 [01:24<13:26, 30.33it/s]

Batch 2000 statistics:
  Loss: 0.6292
  Gradient norm: 0.1568
  Score range: [-0.9637, 0.9112]
  Prediction mean: 0.5304


  8%|▊         | 2102/26474 [01:27<15:49, 25.67it/s]

Batch 2100 statistics:
  Loss: 0.6172
  Gradient norm: 0.1531
  Score range: [-0.9628, 0.9354]
  Prediction mean: 0.5355


  8%|▊         | 2203/26474 [01:32<24:23, 16.58it/s]

Batch 2200 statistics:
  Loss: 0.6199
  Gradient norm: 0.1506
  Score range: [-0.9627, 0.9346]
  Prediction mean: 0.5326


  9%|▊         | 2305/26474 [01:36<15:52, 25.37it/s]

Batch 2300 statistics:
  Loss: 0.6147
  Gradient norm: 0.1358
  Score range: [-0.9579, 0.9107]
  Prediction mean: 0.5312


  9%|▉         | 2405/26474 [01:40<13:50, 28.99it/s]

Batch 2400 statistics:
  Loss: 0.6217
  Gradient norm: 0.1244
  Score range: [-0.9505, 0.9392]
  Prediction mean: 0.5323


  9%|▉         | 2505/26474 [01:44<15:55, 25.10it/s]

Batch 2500 statistics:
  Loss: 0.6171
  Gradient norm: 0.1463
  Score range: [-0.9678, 0.9315]
  Prediction mean: 0.5206


 10%|▉         | 2602/26474 [01:48<16:40, 23.86it/s]

Batch 2600 statistics:
  Loss: 0.6172
  Gradient norm: 0.1306
  Score range: [-0.9508, 0.9429]
  Prediction mean: 0.5363


 10%|█         | 2705/26474 [01:53<14:43, 26.90it/s]

Batch 2700 statistics:
  Loss: 0.6205
  Gradient norm: 0.1056
  Score range: [-0.9796, 0.9588]
  Prediction mean: 0.5415


 11%|█         | 2805/26474 [01:56<14:34, 27.07it/s]

Batch 2800 statistics:
  Loss: 0.6165
  Gradient norm: 0.1270
  Score range: [-0.9441, 0.9450]
  Prediction mean: 0.5367


 11%|█         | 2905/26474 [02:00<14:12, 27.64it/s]

Batch 2900 statistics:
  Loss: 0.6005
  Gradient norm: 0.1210
  Score range: [-0.9529, 0.9165]
  Prediction mean: 0.5284


 11%|█▏        | 3005/26474 [02:04<16:02, 24.37it/s]

Batch 3000 statistics:
  Loss: 0.6021
  Gradient norm: 0.1537
  Score range: [-0.9519, 0.9377]
  Prediction mean: 0.5308


 12%|█▏        | 3101/26474 [02:09<22:44, 17.13it/s]

Batch 3100 statistics:
  Loss: 0.6149
  Gradient norm: 0.1252
  Score range: [-0.9553, 0.9437]
  Prediction mean: 0.5390


 12%|█▏        | 3205/26474 [02:14<16:43, 23.18it/s]

Batch 3200 statistics:
  Loss: 0.6007
  Gradient norm: 0.1152
  Score range: [-0.9726, 0.9210]
  Prediction mean: 0.5338


 12%|█▏        | 3305/26474 [02:18<17:25, 22.16it/s]

Batch 3300 statistics:
  Loss: 0.6122
  Gradient norm: 0.1198
  Score range: [-0.9545, 0.9337]
  Prediction mean: 0.5457


 13%|█▎        | 3401/26474 [02:23<19:42, 19.50it/s]

Batch 3400 statistics:
  Loss: 0.6175
  Gradient norm: 0.1488
  Score range: [-0.9459, 0.9194]
  Prediction mean: 0.5369


 13%|█▎        | 3501/26474 [02:28<18:45, 20.42it/s]

Batch 3500 statistics:
  Loss: 0.6161
  Gradient norm: 0.1203
  Score range: [-0.9535, 0.9433]
  Prediction mean: 0.5400


 14%|█▎        | 3605/26474 [02:35<16:51, 22.61it/s]

Batch 3600 statistics:
  Loss: 0.6341
  Gradient norm: 0.1700
  Score range: [-0.9597, 0.9212]
  Prediction mean: 0.5310


 14%|█▍        | 3705/26474 [02:39<15:14, 24.89it/s]

Batch 3700 statistics:
  Loss: 0.6163
  Gradient norm: 0.0921
  Score range: [-0.9557, 0.9517]
  Prediction mean: 0.5390


 14%|█▍        | 3805/26474 [02:43<16:51, 22.41it/s]

Batch 3800 statistics:
  Loss: 0.6197
  Gradient norm: 0.1274
  Score range: [-0.9781, 0.9514]
  Prediction mean: 0.5511


 15%|█▍        | 3905/26474 [02:48<15:55, 23.63it/s]

Batch 3900 statistics:
  Loss: 0.6240
  Gradient norm: 0.1080
  Score range: [-0.9568, 0.9673]
  Prediction mean: 0.5442


 15%|█▌        | 4005/26474 [02:53<25:20, 14.78it/s]

Batch 4000 statistics:
  Loss: 0.6276
  Gradient norm: 0.1529
  Score range: [-0.9506, 0.9269]
  Prediction mean: 0.5327


 16%|█▌        | 4105/26474 [02:57<14:58, 24.89it/s]

Batch 4100 statistics:
  Loss: 0.6260
  Gradient norm: 0.1250
  Score range: [-0.9471, 0.9455]
  Prediction mean: 0.5394


 16%|█▌        | 4205/26474 [03:01<15:16, 24.30it/s]

Batch 4200 statistics:
  Loss: 0.6176
  Gradient norm: 0.1327
  Score range: [-0.9640, 0.9419]
  Prediction mean: 0.5385


 16%|█▋        | 4305/26474 [03:05<14:43, 25.09it/s]

Batch 4300 statistics:
  Loss: 0.6250
  Gradient norm: 0.1025
  Score range: [-0.9656, 0.9590]
  Prediction mean: 0.5412


 17%|█▋        | 4405/26474 [03:09<14:30, 25.34it/s]

Batch 4400 statistics:
  Loss: 0.6172
  Gradient norm: 0.1984
  Score range: [-0.9772, 0.9504]
  Prediction mean: 0.5460


 17%|█▋        | 4504/26474 [03:14<16:21, 22.38it/s]

Batch 4500 statistics:
  Loss: 0.6206
  Gradient norm: 0.1108
  Score range: [-0.9644, 0.9503]
  Prediction mean: 0.5392


 17%|█▋        | 4604/26474 [03:18<13:27, 27.09it/s]

Batch 4600 statistics:
  Loss: 0.6196
  Gradient norm: 0.1458
  Score range: [-0.9625, 0.9402]
  Prediction mean: 0.5375


 18%|█▊        | 4704/26474 [03:21<13:19, 27.24it/s]

Batch 4700 statistics:
  Loss: 0.6183
  Gradient norm: 0.1248
  Score range: [-0.9532, 0.9442]
  Prediction mean: 0.5422


 18%|█▊        | 4804/26474 [03:25<13:53, 26.00it/s]

Batch 4800 statistics:
  Loss: 0.6057
  Gradient norm: 0.1841
  Score range: [-0.9492, 0.9452]
  Prediction mean: 0.5270


 19%|█▊        | 4905/26474 [03:30<21:48, 16.49it/s]

Batch 4900 statistics:
  Loss: 0.6115
  Gradient norm: 0.1042
  Score range: [-0.9673, 0.9383]
  Prediction mean: 0.5351


 19%|█▉        | 5004/26474 [03:34<12:57, 27.62it/s]

Batch 5000 statistics:
  Loss: 0.6136
  Gradient norm: 0.1142
  Score range: [-0.9528, 0.9511]
  Prediction mean: 0.5369


 19%|█▉        | 5104/26474 [03:38<13:55, 25.57it/s]

Batch 5100 statistics:
  Loss: 0.6111
  Gradient norm: 0.1751
  Score range: [-0.9312, 0.9586]
  Prediction mean: 0.5364


 20%|█▉        | 5204/26474 [03:41<12:41, 27.92it/s]

Batch 5200 statistics:
  Loss: 0.6109
  Gradient norm: 0.1433
  Score range: [-0.9518, 0.9301]
  Prediction mean: 0.5314


 20%|██        | 5302/26474 [03:45<13:48, 25.55it/s]

Batch 5300 statistics:
  Loss: 0.6100
  Gradient norm: 0.1525
  Score range: [-0.9866, 0.9516]
  Prediction mean: 0.5365


 20%|██        | 5404/26474 [03:50<15:35, 22.52it/s]

Batch 5400 statistics:
  Loss: 0.6177
  Gradient norm: 0.1575
  Score range: [-0.9549, 0.9290]
  Prediction mean: 0.5318


 21%|██        | 5504/26474 [03:54<13:06, 26.68it/s]

Batch 5500 statistics:
  Loss: 0.6102
  Gradient norm: 0.1052
  Score range: [-0.9655, 0.9569]
  Prediction mean: 0.5387


 21%|██        | 5604/26474 [03:58<12:02, 28.87it/s]

Batch 5600 statistics:
  Loss: 0.6132
  Gradient norm: 0.1262
  Score range: [-0.9582, 0.9439]
  Prediction mean: 0.5349


 22%|██▏       | 5706/26474 [04:01<13:26, 25.74it/s]

Batch 5700 statistics:
  Loss: 0.6247
  Gradient norm: 0.1831
  Score range: [-0.9769, 0.9391]
  Prediction mean: 0.5519


 22%|██▏       | 5805/26474 [04:06<21:02, 16.37it/s]

Batch 5800 statistics:
  Loss: 0.6169
  Gradient norm: 0.1395
  Score range: [-0.9810, 0.9457]
  Prediction mean: 0.5348


 22%|██▏       | 5904/26474 [04:10<12:12, 28.10it/s]

Batch 5900 statistics:
  Loss: 0.6117
  Gradient norm: 0.1665
  Score range: [-0.9491, 0.9593]
  Prediction mean: 0.5308


 23%|██▎       | 6004/26474 [04:14<12:35, 27.10it/s]

Batch 6000 statistics:
  Loss: 0.6226
  Gradient norm: 0.1232
  Score range: [-0.9592, 0.9420]
  Prediction mean: 0.5321


 23%|██▎       | 6104/26474 [04:18<11:56, 28.42it/s]

Batch 6100 statistics:
  Loss: 0.6107
  Gradient norm: 0.1418
  Score range: [-0.9666, 0.9800]
  Prediction mean: 0.5400


 23%|██▎       | 6167/26474 [04:20<14:17, 23.67it/s]


KeyboardInterrupt: 