In [221]:
# Notebook-wide imports and configuration.
from dotenv import load_dotenv
import os
import pandas as pd
# Called before BASEDIR is read below (presumably so a local .env file can supply it).
load_dotenv()
# Root directory that every data path in this notebook is joined onto.
# NOTE(review): os.getenv returns None when BASEDIR is unset, in which case the
# os.path.join calls further down will raise a TypeError.
base_dir = os.getenv('BASEDIR')
import re
import pickle as pk
import numpy as np
from tqdm import tqdm
# Registers DataFrame.progress_apply, which the edge-building cells rely on.
tqdm.pandas()

### Representation 1
- U -> U

#### Node Types
- user

#### Edge Types
- replied_to
- quoted
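- liked (listed for completeness; the edge-building code visible in this notebook does not emit `liked` edges)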
- retweeted
- mentioned

### Representation 2
- U -> T -> U
- U -> T <- U

#### Node Types
- user
- tweet

#### Edge Types
- posted (U -> T)
- replied_to (U -> T)
- quoted (U -> T)
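- liked (U -> T) (listed for completeness; the edge-building code visible in this notebook does not emit `liked` edges)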
- retweeted (U -> T)
- mentioned (T -> U)

# Qanda

In [None]:
# Read the Qanda export with explicit column names; every value is kept as a string.
data = pd.read_csv(
    os.path.join(base_dir, 'data', '01_raw_data', 'qanda', 'qanda_hetero_graph_data.csv'),
    names=[
        'tid', 'uid',
        'reply_id', 'reply_uid',
        'quoted_id', 'quoted_uid',
        'liked_status',
        'retweeted_id', 'retweeted_uid',
        'mention_ids',
    ],
    dtype=str,
)
# Mentions are stored as one ';;;'-separated string: turn each into a list of
# ids, dropping empty pieces (a missing value becomes an empty list).
data['mention_ids'] = data['mention_ids'].fillna('').apply(
    lambda joined: [piece for piece in joined.split(';;;') if len(piece) > 0]
)

In [None]:
def handle_row_r1(row):
    """Build the user -> user edges (Representation 1) for one Qanda row.

    The row's `uid` is the source of every edge. One edge is emitted for each
    of `reply_uid`, `quoted_uid` and `retweeted_uid` that is not missing,
    followed by one `mentioned` edge per entry of `mention_ids`.

    Returns:
        list of dicts with the keys `source`, `target` and `edge_type`.
    """
    author = row['uid']
    # (column holding the target user, edge label), in emission order.
    single_targets = (
        ('reply_uid', 'replied_to'),
        ('quoted_uid', 'quoted'),
        ('retweeted_uid', 'retweeted'),
    )
    edges = [
        {'source': author, 'target': row[column], 'edge_type': label}
        for column, label in single_targets
        if not pd.isna(row[column])
    ]
    edges.extend(
        {'source': author, 'target': mentioned, 'edge_type': 'mentioned'}
        for mentioned in row['mention_ids']
    )
    return edges
        
def handle_row_r2(row):
    """Build the user/tweet edges (Representation 2) for one Qanda row.

    Emits, in this order:
      * user -> tweet edges from the row's `uid` to each non-missing
        `reply_id`, `quoted_id` and `retweeted_id`;
      * tweet -> user `mentioned` edges from the row's `tid` to every entry
        of `mention_ids`;
      * a single `posted` edge from `uid` to `tid`.

    Returns:
        list of dicts with the keys `source`, `target` and `edge_type`.
    """
    author, tweet = row['uid'], row['tid']
    # (column holding the referenced tweet id, edge label), in emission order.
    referenced = (
        ('reply_id', 'replied_to'),
        ('quoted_id', 'quoted'),
        ('retweeted_id', 'retweeted'),
    )

    edges = []
    for column, label in referenced:
        target_tweet = row[column]
        if pd.isna(target_tweet):
            continue
        edges.append({'source': author, 'target': target_tweet, 'edge_type': label})

    edges += [
        {'source': tweet, 'target': mentioned, 'edge_type': 'mentioned'}
        for mentioned in row['mention_ids']
    ]
    edges.append({'source': author, 'target': tweet, 'edge_type': 'posted'})
    return edges

In [None]:
# Run the Representation 1 builder over every row (with a progress bar) and
# flatten the per-row edge lists into one list holding a dict per edge.
r1 = []
for row_edges in data.progress_apply(handle_row_r1, axis=1):
    r1.extend(row_edges)
# One row per edge, one column per dict key.
r1_df = pd.DataFrame(r1)

In [None]:
# Run the Representation 2 builder over every row (with a progress bar) and
# flatten the per-row edge lists into one list holding a dict per edge.
r2 = []
for row_edges in data.progress_apply(handle_row_r2, axis=1):
    r2.extend(row_edges)
# One row per edge, one column per dict key.
r2_df = pd.DataFrame(r2)

In [None]:
# Write both Qanda edge lists into the dataset's raw-data folder.
# to_csv is called with its defaults, so the row index is written as well.
qanda_dir = os.path.join(base_dir, 'data', '01_raw_data', 'qanda')
r1_df.to_csv(os.path.join(qanda_dir, 'qanda_hetero_graph_edges_r1.csv'))
r2_df.to_csv(os.path.join(qanda_dir, 'qanda_hetero_graph_edges_r2.csv'))

# Ausvotes

In [209]:
# Read the Ausvotes export with explicit column names; every value is kept as a string.
data = pd.read_csv(
    os.path.join(base_dir, 'data', '01_raw_data', 'ausvotes', 'ausvotes_hetero_graph.csv'),
    names=['tid', 'uid', 'reply_uid', 'referenced_tweets', 'conversation_id', 'mention_ids'],
    dtype=str,
)
# Mentions are stored as one ';;;'-separated string: turn each into a list of
# ids, dropping empty pieces (a missing value becomes an empty list).
data['mention_ids'] = data['mention_ids'].fillna('').apply(
    lambda joined: [piece for piece in joined.split(';;;') if len(piece) > 0]
)

In [210]:
# Keep only the rows that carry a tweet id.
data = data.dropna(subset=['tid'])

In [211]:
# Two-column view (tweet id, author id) used to resolve a tweet to its author.
tid_to_uid = data.loc[:, ['tid', 'uid']]

In [212]:
# Keep the first row seen for each tweet id and drop later repeats.
data = data.loc[~tid_to_uid['tid'].duplicated(keep='first')]

In [213]:
# Keep the first row per tweet id, then index by tweet id so that
# `tid in tid_to_uid.index` and `tid_to_uid.loc[tid, 'uid']` work.
tid_to_uid = tid_to_uid.drop_duplicates(subset='tid', keep='first').set_index('tid')

In [214]:
def handle_referenced_tweets(s):
    """Parse a ';;;'-separated string of `type:tweet_id` entries.

    Empty pieces are ignored. Every remaining piece must split on ':' into
    exactly two parts; otherwise the unpacking raises ValueError.

    Returns:
        list of dicts `{'target': tweet_id, 'edge_type': type}`, in input order.
    """
    parsed = []
    for entry in filter(None, s.split(';;;')):
        edge_kind, target_tid = entry.split(':')
        parsed.append({'target': target_tid, 'edge_type': edge_kind})
    return parsed
# Replace the raw strings with parsed edge stubs (a missing value becomes an empty list).
# Not re-runnable as is: after this line the column holds lists, while
# handle_referenced_tweets calls .split on its argument.
raw_refs = data['referenced_tweets'].fillna('')
data['referenced_tweets'] = raw_refs.apply(handle_referenced_tweets)

In [215]:
def handle_row_r1(row):
    """Build the user -> user edges (Representation 1) for one Ausvotes row.

    Each referenced tweet is resolved to its author through the module-level
    `tid_to_uid` table (indexed by tweet id); references to tweets that are not
    in that table are skipped. One `mentioned` edge is added for every entry of
    `mention_ids`.

    Rows whose `uid` is missing produce no edges, matching the guard in
    `handle_row_r2`; a missing source would otherwise be written to the edge
    CSV and come back as NaN among the string ids.

    Returns:
        list of dicts with the keys `source`, `target` and `edge_type`.
    """
    source = row['uid']
    if pd.isna(source):
        return []

    edges = []
    for rt in row['referenced_tweets']:
        # Only tweets present in this dataset can be mapped to an author.
        if rt['target'] in tid_to_uid.index:
            target_uid = tid_to_uid.loc[rt['target'], 'uid']
            # The referenced tweet may itself lack an author id; skip it
            # rather than emit an edge with a missing target.
            if pd.isna(target_uid):
                continue
            edges.append({'source': source, 'target': target_uid, 'edge_type': rt['edge_type']})
    for mention_uid in row['mention_ids']:
        edges.append({'source': source, 'target': mention_uid, 'edge_type': 'mentioned'})

    return edges

In [216]:
def handle_row_r2(row):
    """Build the user/tweet edges (Representation 2) for one Ausvotes row.

    A row with a missing `uid` or `tid` is printed and yields no edges.
    Otherwise the result contains, in this order:
      * one user -> tweet edge per entry of `referenced_tweets`, labelled with
        that entry's `edge_type`;
      * one tweet -> user `mentioned` edge per entry of `mention_ids`;
      * a single `posted` edge from `uid` to `tid`.

    Returns:
        list of dicts with the keys `source`, `target` and `edge_type`.
    """
    author, tweet = row['uid'], row['tid']
    if pd.isna(author) or pd.isna(tweet):
        print(row)
        return []

    references = row['referenced_tweets']
    mentions = row['mention_ids']

    user_to_tweet = [
        {'source': author, 'target': ref['target'], 'edge_type': ref['edge_type']}
        for ref in references
    ]
    tweet_to_user = [
        {'source': tweet, 'target': mentioned, 'edge_type': 'mentioned'}
        for mentioned in mentions
    ]
    authored = {'source': author, 'target': tweet, 'edge_type': 'posted'}
    return user_to_tweet + tweet_to_user + [authored]

In [217]:
# Run the Representation 1 builder over every row (with a progress bar) and
# flatten the per-row edge lists into one list holding a dict per edge.
r1 = []
for row_edges in data.progress_apply(handle_row_r1, axis=1):
    r1.extend(row_edges)
# One row per edge, one column per dict key.
r1_df = pd.DataFrame(r1)


  0%|                                                                      | 0/5033982 [00:00<?, ?it/s][A
  0%|                                                         | 1/5033982 [00:01<1809:20:14,  1.29s/it][A
  0%|                                                         | 2/5033982 [00:03<2573:40:28,  1.84s/it][A
  0%|                                                         | 3905/5033982 [00:03<49:10, 1704.70it/s][A
  0%|                                                         | 7864/5033982 [00:03<21:19, 3929.10it/s][A
  0%|▏                                                       | 11762/5033982 [00:03<12:36, 6641.44it/s][A
  0%|▏                                                       | 15591/5033982 [00:03<08:32, 9791.85it/s][A
  0%|▏                                                      | 19501/5033982 [00:04<06:13, 13440.42it/s][A
  0%|▎                                                      | 23393/5033982 [00:04<04:49, 17287.60it/s][A
  1%|▎                              

 12%|██████▍                                               | 604384/5033982 [00:18<01:53, 39152.24it/s][A
 12%|██████▌                                               | 608393/5033982 [00:18<01:52, 39432.51it/s][A
 12%|██████▌                                               | 612337/5033982 [00:18<01:52, 39307.81it/s][A
 12%|██████▌                                               | 616318/5033982 [00:18<01:51, 39455.57it/s][A
 12%|██████▋                                               | 620337/5033982 [00:19<01:51, 39675.09it/s][A
 12%|██████▋                                               | 624305/5033982 [00:19<01:51, 39640.57it/s][A
 12%|██████▋                                               | 628270/5033982 [00:19<01:51, 39556.25it/s][A
 13%|██████▊                                               | 632347/5033982 [00:19<01:50, 39916.03it/s][A
 13%|██████▊                                               | 636339/5033982 [00:19<01:50, 39906.57it/s][A
 13%|██████▊                         

 24%|████████████▊                                        | 1214127/5033982 [00:33<01:36, 39775.09it/s][A
 24%|████████████▊                                        | 1218162/5033982 [00:34<01:35, 39945.80it/s][A
 24%|████████████▊                                        | 1222158/5033982 [00:34<01:35, 39948.66it/s][A
 24%|████████████▉                                        | 1226168/5033982 [00:34<01:35, 39993.73it/s][A
 24%|████████████▉                                        | 1230264/5033982 [00:34<01:34, 40279.98it/s][A
 25%|████████████▉                                        | 1234293/5033982 [00:34<01:34, 40165.66it/s][A
 25%|█████████████                                        | 1238320/5033982 [00:34<01:34, 40194.09it/s][A
 25%|█████████████                                        | 1242340/5033982 [00:34<01:34, 40156.73it/s][A
 25%|█████████████                                        | 1246356/5033982 [00:34<01:35, 39852.81it/s][A
 25%|█████████████▏                  

 36%|███████████████████                                  | 1816180/5033982 [00:49<01:21, 39664.42it/s][A
 36%|███████████████████▏                                 | 1820194/5033982 [00:49<01:20, 39805.19it/s][A
 36%|███████████████████▏                                 | 1824224/5033982 [00:49<01:20, 39951.15it/s][A
 36%|███████████████████▏                                 | 1828220/5033982 [00:49<01:20, 39730.25it/s][A
 36%|███████████████████▎                                 | 1832203/5033982 [00:49<01:20, 39759.76it/s][A
 36%|███████████████████▎                                 | 1836198/5033982 [00:49<01:20, 39816.60it/s][A
 37%|███████████████████▎                                 | 1840233/5033982 [00:49<01:19, 39975.97it/s][A
 37%|███████████████████▍                                 | 1844231/5033982 [00:50<01:19, 39916.95it/s][A
 37%|███████████████████▍                                 | 1848263/5033982 [00:50<01:19, 40035.18it/s][A
 37%|███████████████████▌            

 48%|█████████████████████████▋                           | 2437129/5033982 [01:04<01:03, 41127.21it/s][A
 48%|█████████████████████████▋                           | 2441324/5033982 [01:04<01:02, 41372.91it/s][A
 49%|█████████████████████████▋                           | 2445462/5033982 [01:04<01:02, 41243.86it/s][A
 49%|█████████████████████████▊                           | 2449635/5033982 [01:04<01:02, 41388.13it/s][A
 49%|█████████████████████████▊                           | 2453774/5033982 [01:05<01:03, 40814.50it/s][A
 49%|█████████████████████████▉                           | 2457858/5033982 [01:05<01:03, 40346.75it/s][A
 49%|█████████████████████████▉                           | 2461895/5033982 [01:05<01:04, 40111.86it/s][A
 49%|█████████████████████████▉                           | 2465916/5033982 [01:05<01:03, 40140.19it/s][A
 49%|██████████████████████████                           | 2470049/5033982 [01:05<01:03, 40492.30it/s][A
 49%|██████████████████████████      

 60%|████████████████████████████████                     | 3042294/5033982 [01:20<00:50, 39131.87it/s][A
 61%|████████████████████████████████                     | 3046235/5033982 [01:20<00:50, 39214.10it/s][A
 61%|████████████████████████████████                     | 3050210/5033982 [01:20<00:50, 39371.10it/s][A
 61%|████████████████████████████████▏                    | 3054200/5033982 [01:20<00:50, 39528.85it/s][A
 61%|████████████████████████████████▏                    | 3058226/5033982 [01:20<00:49, 39746.17it/s][A
 61%|████████████████████████████████▏                    | 3062259/5033982 [01:20<00:49, 39918.53it/s][A
 61%|████████████████████████████████▎                    | 3066252/5033982 [01:20<00:49, 39782.71it/s][A
 61%|████████████████████████████████▎                    | 3070294/5033982 [01:20<00:49, 39970.64it/s][A
 61%|████████████████████████████████▎                    | 3074296/5033982 [01:20<00:49, 39985.36it/s][A
 61%|████████████████████████████████

 72%|██████████████████████████████████████▎              | 3633605/5033982 [01:35<00:37, 37823.90it/s][A
 72%|██████████████████████████████████████▎              | 3637396/5033982 [01:35<00:36, 37846.89it/s][A
 72%|██████████████████████████████████████▎              | 3641182/5033982 [01:35<00:36, 37705.32it/s][A
 72%|██████████████████████████████████████▍              | 3644997/5033982 [01:35<00:36, 37836.44it/s][A
 72%|██████████████████████████████████████▍              | 3648782/5033982 [01:35<00:36, 37824.05it/s][A
 73%|██████████████████████████████████████▍              | 3652565/5033982 [01:35<00:36, 37714.64it/s][A
 73%|██████████████████████████████████████▍              | 3656337/5033982 [01:35<00:37, 37053.63it/s][A
 73%|██████████████████████████████████████▌              | 3660124/5033982 [01:36<00:36, 37294.44it/s][A
 73%|██████████████████████████████████████▌              | 3663905/5033982 [01:36<00:36, 37445.27it/s][A
 73%|████████████████████████████████

 84%|████████████████████████████████████████████▌        | 4231965/5033982 [01:50<00:20, 39669.44it/s][A
 84%|████████████████████████████████████████████▌        | 4235932/5033982 [01:50<00:20, 39577.19it/s][A
 84%|████████████████████████████████████████████▋        | 4239963/5033982 [01:50<00:19, 39794.86it/s][A
 84%|████████████████████████████████████████████▋        | 4243963/5033982 [01:51<00:19, 39855.07it/s][A
 84%|████████████████████████████████████████████▋        | 4247954/5033982 [01:51<00:19, 39868.15it/s][A
 84%|████████████████████████████████████████████▊        | 4251941/5033982 [01:51<00:19, 39727.69it/s][A
 85%|████████████████████████████████████████████▊        | 4255914/5033982 [01:51<00:19, 39684.11it/s][A
 85%|████████████████████████████████████████████▊        | 4259982/5033982 [01:51<00:19, 39980.94it/s][A
 85%|████████████████████████████████████████████▉        | 4263981/5033982 [01:51<00:19, 39940.33it/s][A
 85%|████████████████████████████████

 96%|██████████████████████████████████████████████████▉  | 4835302/5033982 [02:06<00:05, 39341.21it/s][A
 96%|██████████████████████████████████████████████████▉  | 4839247/5033982 [02:06<00:04, 39372.15it/s][A
 96%|██████████████████████████████████████████████████▉  | 4843185/5033982 [02:06<00:04, 38961.57it/s][A
 96%|███████████████████████████████████████████████████  | 4847151/5033982 [02:06<00:04, 39168.34it/s][A
 96%|███████████████████████████████████████████████████  | 4851112/5033982 [02:06<00:04, 39299.19it/s][A
 96%|███████████████████████████████████████████████████  | 4855137/5033982 [02:06<00:04, 39582.03it/s][A
 97%|███████████████████████████████████████████████████▏ | 4859164/5033982 [02:06<00:04, 39786.19it/s][A
 97%|███████████████████████████████████████████████████▏ | 4863144/5033982 [02:06<00:04, 39601.80it/s][A
 97%|███████████████████████████████████████████████████▏ | 4867174/5033982 [02:06<00:04, 39810.05it/s][A
 97%|████████████████████████████████

In [218]:
# Run the Representation 2 builder over every row (with a progress bar) and
# flatten the per-row edge lists into one list holding a dict per edge.
r2 = []
for row_edges in data.progress_apply(handle_row_r2, axis=1):
    r2.extend(row_edges)
# One row per edge, one column per dict key.
r2_df = pd.DataFrame(r2)


  0%|                                                                      | 0/5033982 [00:00<?, ?it/s][A
  0%|                                                         | 134/5033982 [00:00<2:55:46, 477.32it/s][A
  0%|                                                         | 447/5033982 [00:00<1:38:55, 848.09it/s][A
  0%|                                                        | 6171/5033982 [00:00<06:04, 13788.46it/s][A
  0%|▏                                                      | 12236/5033982 [00:00<03:18, 25357.70it/s][A
  0%|▏                                                      | 18166/5033982 [00:00<02:26, 34206.23it/s][A
  0%|▎                                                      | 24252/5033982 [00:00<02:00, 41439.05it/s][A
  1%|▎                                                      | 30402/5033982 [00:01<01:46, 47042.04it/s][A
  1%|▍                                                      | 36643/5033982 [00:01<01:37, 51425.33it/s][A
  1%|▍                              

 19%|██████████                                            | 942222/5033982 [00:15<01:06, 61578.13it/s][A
 19%|██████████▏                                           | 948383/5033982 [00:15<01:06, 61372.24it/s][A
 19%|██████████▏                                           | 954790/5033982 [00:15<01:05, 62173.53it/s][A
 19%|██████████▎                                           | 961062/5033982 [00:15<01:05, 62334.58it/s][A
 19%|██████████▍                                           | 967383/5033982 [00:16<01:04, 62594.04it/s][A
 19%|██████████▍                                           | 973656/5033982 [00:16<01:04, 62632.32it/s][A
 19%|██████████▌                                           | 979941/5033982 [00:16<01:04, 62694.82it/s][A
 20%|██████████▌                                           | 986212/5033982 [00:16<01:04, 62323.97it/s][A
 20%|██████████▋                                           | 992579/5033982 [00:16<01:04, 62725.11it/s][A
 20%|██████████▋                     

 37%|███████████████████▊                                 | 1880279/5033982 [00:44<00:51, 61639.09it/s][A
 37%|███████████████████▊                                 | 1886446/5033982 [00:44<00:51, 61485.37it/s][A
 38%|███████████████████▉                                 | 1892597/5033982 [00:44<00:51, 61206.15it/s][A
 38%|███████████████████▉                                 | 1898840/5033982 [00:44<00:50, 61570.77it/s][A
 38%|████████████████████                                 | 1905061/5033982 [00:44<00:50, 61760.71it/s][A
 38%|████████████████████                                 | 1911338/5033982 [00:44<00:50, 62061.78it/s][A
 38%|████████████████████▏                                | 1917546/5033982 [00:44<00:50, 61775.18it/s][A
 38%|████████████████████▎                                | 1923833/5033982 [00:44<00:50, 62100.64it/s][A
 38%|████████████████████▎                                | 1930130/5033982 [00:45<00:49, 62359.71it/s][A
 38%|████████████████████▍           

 57%|█████████████████████████████▉                       | 2844648/5033982 [00:59<00:35, 60876.78it/s][A
 57%|██████████████████████████████                       | 2850976/5033982 [00:59<00:35, 61581.65it/s][A
 57%|██████████████████████████████                       | 2857226/5033982 [00:59<00:35, 61852.55it/s][A
 57%|██████████████████████████████▏                      | 2863456/5033982 [00:59<00:35, 61928.09it/s][A
 57%|██████████████████████████████▏                      | 2869747/5033982 [00:59<00:34, 62219.87it/s][A
 57%|██████████████████████████████▎                      | 2875991/5033982 [01:00<00:34, 62127.92it/s][A
 57%|██████████████████████████████▎                      | 2882423/5033982 [01:00<00:34, 62779.38it/s][A
 57%|██████████████████████████████▍                      | 2888730/5033982 [01:00<00:34, 62866.15it/s][A
 58%|██████████████████████████████▍                      | 2895025/5033982 [01:00<00:34, 62674.53it/s][A
 58%|██████████████████████████████▌ 

 76%|████████████████████████████████████████▏            | 3811929/5033982 [01:14<00:19, 63089.87it/s][A
 76%|████████████████████████████████████████▏            | 3818241/5033982 [01:15<00:19, 62741.53it/s][A
 76%|████████████████████████████████████████▎            | 3824721/5033982 [01:15<00:19, 63354.59it/s][A
 76%|████████████████████████████████████████▎            | 3831229/5033982 [01:15<00:18, 63868.44it/s][A
 76%|████████████████████████████████████████▍            | 3837618/5033982 [01:15<00:18, 63809.41it/s][A
 76%|████████████████████████████████████████▍            | 3844000/5033982 [01:15<00:18, 63768.03it/s][A
 76%|████████████████████████████████████████▌            | 3850378/5033982 [01:15<00:18, 63444.91it/s][A
 77%|████████████████████████████████████████▌            | 3856868/5033982 [01:15<00:18, 63877.35it/s][A
 77%|████████████████████████████████████████▋            | 3863279/5033982 [01:15<00:18, 63946.85it/s][A
 77%|████████████████████████████████

 95%|██████████████████████████████████████████████████▏  | 4768378/5033982 [01:30<00:04, 61923.65it/s][A
 95%|██████████████████████████████████████████████████▎  | 4774572/5033982 [01:30<00:04, 61598.08it/s][A
 95%|██████████████████████████████████████████████████▎  | 4780733/5033982 [01:30<00:04, 61590.02it/s][A
 95%|██████████████████████████████████████████████████▍  | 4786921/5033982 [01:30<00:04, 61674.03it/s][A
 95%|██████████████████████████████████████████████████▍  | 4793143/5033982 [01:30<00:03, 61834.97it/s][A
 95%|██████████████████████████████████████████████████▌  | 4799343/5033982 [01:30<00:03, 61882.32it/s][A
 95%|██████████████████████████████████████████████████▌  | 4805532/5033982 [01:30<00:03, 61639.66it/s][A
 96%|██████████████████████████████████████████████████▋  | 4811873/5033982 [01:30<00:03, 62165.84it/s][A
 96%|██████████████████████████████████████████████████▋  | 4818199/5033982 [01:31<00:03, 62490.19it/s][A
 96%|████████████████████████████████

In [219]:
# Write both Ausvotes edge lists into the dataset's raw-data folder.
# to_csv is called with its defaults, so the row index is written as well.
ausvotes_dir = os.path.join(base_dir, 'data', '01_raw_data', 'ausvotes')
r1_df.to_csv(os.path.join(ausvotes_dir, 'ausvotes_hetero_graph_edges_r1.csv'))
r2_df.to_csv(os.path.join(ausvotes_dir, 'ausvotes_hetero_graph_edges_r2.csv'))

### Representation 1

In [125]:
import torch
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
import pandas as pd
from sklearn.model_selection import train_test_split

# Datasets and ground-truth variants to build a Representation 1 (user -> user) graph for.
datasets = ['qanda', 'ausvotes']
ground_truths = ['URL_LR']

for dataset in datasets:
    for gt in ground_truths:
        # Edge list written by the edge-building cells; every column is read as a string.
        r1_df = pd.read_csv(os.path.join(base_dir,'data','01_raw_data',dataset,dataset+'_hetero_graph_edges_r1.csv'), dtype=str)
        # NOTE(review): pickle.load can execute arbitrary code; only load ground-truth files from a trusted source.
        with open(os.path.join(base_dir,'data','03_processed',dataset,'ground_truth',dataset+'_'+gt+'__per_user.pk'), 'rb') as rf:
            ground_truth =  pk.load(rf).reset_index(drop=False).drop(['level_1'], axis=1).set_index('uid')['urls']

        # Every user that appears at either end of an edge gets a contiguous integer id.
        user_to_index = pd.DataFrame({'uid': np.unique(np.hstack([r1_df['source'].values,r1_df['target'].values]))}, dtype=str).set_index('uid')
        user_to_index['numeric_index'] = np.arange(0,len(user_to_index))

        data = HeteroData()

        def get_from_gt(uid):
            """Return the ground-truth value for `uid`, or -1 when the user has none."""
            try:
                return ground_truth.loc[uid]
            except KeyError:
                # Only a missing label maps to -1; any other error should surface.
                return -1
        y = user_to_index.index.to_series().apply(get_from_gt)

        labelled_users = np.array(y[y != -1].index)

        # 20% of the labelled users for test, then 25% of the remainder for
        # validation, i.e. a 60/20/20 train/val/test split.
        X_train, X_test, y_train, y_test = train_test_split(labelled_users, np.arange(0, len(labelled_users)), test_size=0.2, random_state=1)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=1)

        # Index.isin tests membership against a hash table; `uid in ndarray`
        # inside .apply rescans the whole split array for every user.
        train_mask = user_to_index.index.isin(X_train)
        val_mask = user_to_index.index.isin(X_val)
        test_mask = user_to_index.index.isin(X_test)

        # NOTE(review): this is a 1-D float64 tensor of ones; confirm that the
        # consumer does not expect features shaped [num_nodes, num_features].
        data['user'].x = torch.tensor(np.ones(len(user_to_index)))

        data['user'].y = torch.tensor(y.values)

        data['user'].train_mask = torch.tensor(train_mask)
        data['user'].val_mask = torch.tensor(val_mask)
        data['user'].test_mask = torch.tensor(test_mask)

        # uid -> integer id lookup, applied to whole columns at once.
        uid_to_numeric = user_to_index['numeric_index']
        for edge_type in ['quoted','retweeted','mentioned','replied_to']:
            filtered_df = r1_df[r1_df['edge_type'] == edge_type]
            # The explicit int64 cast keeps the arrays numeric even when an edge
            # type has no rows, and fails loudly if an id could not be resolved.
            source_col = filtered_df['source'].map(uid_to_numeric).to_numpy(dtype=np.int64)
            target_col = filtered_df['target'].map(uid_to_numeric).to_numpy(dtype=np.int64)
            data['user',edge_type, 'user'].edge_index = torch.tensor(np.vstack([source_col,target_col]))

        data['user'].num_nodes = len(user_to_index)
        path_save =  os.path.join(base_dir,'data','01_raw_data',dataset,dataset+'_'+gt+'_hetero_graph_r1.pt')
        torch.save(data, path_save)

### Representation 2

In [220]:
import torch
from torch_geometric.data import HeteroData

import pandas as pd
from sklearn.model_selection import train_test_split

# Datasets and ground-truth variants to build a Representation 2 (user/tweet) graph for.
datasets = ['ausvotes','qanda']
ground_truths = ['URL_LR']

for dataset in datasets:
    for gt in ground_truths:
        # Edge list written by the edge-building cells; every column is read as a string.
        r2_df = pd.read_csv(os.path.join(base_dir,'data','01_raw_data',dataset,dataset+'_hetero_graph_edges_r2.csv'), dtype=str)
        # NOTE(review): pickle.load can execute arbitrary code; only load ground-truth files from a trusted source.
        with open(os.path.join(base_dir,'data','03_processed',dataset,'ground_truth',dataset+'_'+gt+'__per_user.pk'), 'rb') as rf:
            ground_truth =  pk.load(rf).reset_index(drop=False).drop(['level_1'], axis=1).set_index('uid')['urls']

        # 'mentioned' edges run tweet -> user; every other edge type runs user -> tweet.
        mentioned_r2_df = r2_df[r2_df['edge_type'] == 'mentioned']
        other_r2_df =  r2_df[r2_df['edge_type'] != 'mentioned']

        # Users are the targets of mentions and the sources of everything else;
        # tweets are the sources of mentions and the targets of everything else.
        user_indices = np.unique(np.hstack([mentioned_r2_df['target'].values,other_r2_df['source'].values]))
        tweet_indices = np.unique(np.hstack([mentioned_r2_df['source'].values,other_r2_df['target'].values]))

        user_to_index = pd.DataFrame({'uid': user_indices}).set_index('uid')
        user_to_index['numeric_index'] = np.arange(0,len(user_to_index))

        tweet_to_index = pd.DataFrame({'tid': tweet_indices}).set_index('tid')
        tweet_to_index['numeric_index'] = np.arange(0,len(tweet_to_index))

        def get_from_gt(uid):
            """Return the ground-truth value for `uid`, or -1 when the user has none."""
            try:
                return ground_truth.loc[uid]
            except KeyError:
                # Only a missing label maps to -1; any other error should surface.
                return -1
        user_y = user_to_index.index.to_series().apply(get_from_gt)

        labelled_users = np.array(user_y[user_y != -1].index)

        # 20% of the labelled users for test, then 25% of the remainder for
        # validation, i.e. a 60/20/20 train/val/test split.
        X_train, X_test, y_train, y_test = train_test_split(labelled_users, np.arange(0, len(labelled_users)), test_size=0.2, random_state=1)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=1)

        # Index.isin tests membership against a hash table; `uid in ndarray`
        # inside .apply rescans the whole split array for every user.
        train_mask = user_to_index.index.isin(X_train)
        val_mask = user_to_index.index.isin(X_val)
        test_mask = user_to_index.index.isin(X_test)

        data = HeteroData()

        # NOTE(review): these are 1-D float64 tensors of ones; confirm that the
        # consumer does not expect features shaped [num_nodes, num_features].
        data['user'].x = torch.tensor(np.ones(len(user_to_index)))
        data['tweet'].x = torch.tensor(np.ones(len(tweet_to_index)))

        data['user'].y = torch.tensor(user_y.values)

        data['user'].train_mask = torch.tensor(train_mask)
        data['user'].val_mask = torch.tensor(val_mask)
        data['user'].test_mask = torch.tensor(test_mask)

        # id -> integer id lookups, applied to whole columns at once.
        uid_to_numeric = user_to_index['numeric_index']
        tid_to_numeric = tweet_to_index['numeric_index']

        for edge_type in ['quoted','retweeted','replied_to', 'posted']:
            filtered_df = r2_df[r2_df['edge_type'] == edge_type]
            # The explicit int64 cast keeps the arrays numeric even when an edge
            # type has no rows, and fails loudly if an id could not be resolved.
            source_col = filtered_df['source'].map(uid_to_numeric).to_numpy(dtype=np.int64)
            target_col = filtered_df['target'].map(tid_to_numeric).to_numpy(dtype=np.int64)
            try:
                data['user',edge_type, 'tweet'].edge_index = torch.tensor(np.vstack([source_col,target_col]))
            except Exception:
                # Dump the offending arrays before re-raising to ease debugging.
                print(edge_type)
                print(source_col)
                print(target_col)
                print(np.vstack([source_col,target_col]))
                raise

        # Mentions are the only tweet -> user relation.
        edge_type = 'mentioned'
        filtered_df = r2_df[r2_df['edge_type'] == edge_type]
        source_col = filtered_df['source'].map(tid_to_numeric).to_numpy(dtype=np.int64)
        target_col = filtered_df['target'].map(uid_to_numeric).to_numpy(dtype=np.int64)
        data['tweet',edge_type, 'user'].edge_index = torch.tensor(np.vstack([source_col,target_col]))

        data['user'].num_nodes = len(user_to_index)
        data['tweet'].num_nodes = len(tweet_to_index)
        path_save =  os.path.join(base_dir,'data','01_raw_data',dataset,dataset+'_'+gt+'_hetero_graph_r2.pt')
        torch.save(data, path_save)