Skip to content

Commit

Permalink
Init submit
Browse files Browse the repository at this point in the history
  • Loading branch information
cszhangzhen committed Feb 16, 2024
1 parent 8862a11 commit d721f51
Show file tree
Hide file tree
Showing 10 changed files with 1,455 additions and 1 deletion.
Binary file added .DS_Store
Binary file not shown.
32 changes: 31 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,31 @@
# GraphCTA
# GraphCTA
Collaborate to Adapt: Source-Free Graph Domain Adaptation via Bi-directional Adaptation (WWW 2024)

![](https://github.com/cszhangzhen/GraphCTA/blob/main/fig/model.png)

This is a PyTorch implementation of the GraphCTA algorithm, which tries to address the domain adaptation problem without accessing the labelled source graph. It performs model adaptation and graph adaptation collaboratively through a series of procedures: (1) conduct model adaptation based on node's neighborhood predictions in target graph considering both local and global information; (2) perform graph adaptation by updating graph structure and node attributes via neighborhood constrastive learning; and (3) the updated graph serves as an input to facilitate the subsequent iteration of model adaptation, thereby establishing a collaborative loop between model adaptation and graph adaptation.


## Requirements
* python3.8
* pytorch==2.0.0
* torch-scatter==2.1.1+pt20cu118
* torch-sparse==0.6.17+pt20cu118
* torch-cluster==1.6.1+pt20cu118
* torch-geometric==2.3.1
* numpy==1.24.3
* scipy==1.10.1
* tqdm==4.65.0

## Datasets
Datasets used in the paper are all publicly available datasets. You can find [Elliptic](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set), [Twitch](https://github.com/benedekrozemberczki/datasets#twitch-social-networks) and [Citation](https://github.com/yuntaodu/ASN/tree/main/data) via the links.

## Quick Start:
Just execuate the following command for source model pre-training:
```
python train_source.py
```
Then, execuate the following command for adaptation:
```
python train_target.py
```
Binary file added data/Citation.zip
Binary file not shown.
271 changes: 271 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.io import read_txt_array
import torch.nn.functional as F

import scipy
import pickle as pkl
import csv
import json

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)


class CitationDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return ["docs.txt", "edgelist.txt", "labels.txt"]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def process(self):
edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
f = open(docs_path, 'rb')
content_list = []
for line in f.readlines():
line = str(line, encoding="utf-8")
content_list.append(line.split(","))
x = np.array(content_list, dtype=float)
x = torch.from_numpy(x).to(torch.float)

label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
f = open(label_path, 'rb')
content_list = []
for line in f.readlines():
line = str(line, encoding="utf-8")
line = line.replace("\r", "").replace("\n", "")
content_list.append(line)
y = np.array(content_list, dtype=int)
y = torch.from_numpy(y).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate([data])

torch.save((data, slices), self.processed_paths[0])


class EllipticDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(EllipticDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return [".pkl"]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def process(self):
path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
result = pkl.load(open(path, 'rb'))
A, label, features = result
label = label + 1
edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
features = np.array(features)
x = torch.from_numpy(features).to(torch.float)
y = torch.tensor(label).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate([data])

torch.save((data, slices), self.processed_paths[0])


class TwitchDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return ["edges.csv, features.json, target.csv"]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def load_twitch(self, lang):
assert lang in ('DE', 'EN', 'FR'), 'Invalid dataset'
filepath = self.raw_dir
label = []
node_ids = []
src = []
targ = []
uniq_ids = set()
with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
reader = csv.reader(f)
next(reader)
for row in reader:
node_id = int(row[5])
# handle FR case of non-unique rows
if node_id not in uniq_ids:
uniq_ids.add(node_id)
label.append(int(row[2]=="True"))
node_ids.append(int(row[5]))

node_ids = np.array(node_ids, dtype=np.int32)

with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
reader = csv.reader(f)
next(reader)
for row in reader:
src.append(int(row[0]))
targ.append(int(row[1]))

with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
j = json.load(f)

src = np.array(src)
targ = np.array(targ)
label = np.array(label)

inv_node_ids = {node_id:idx for (idx, node_id) in enumerate(node_ids)}
reorder_node_ids = np.zeros_like(node_ids)
for i in range(label.shape[0]):
reorder_node_ids[i] = inv_node_ids[i]

n = label.shape[0]
A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n,n))
features = np.zeros((n,3170))
for node, feats in j.items():
if int(node) >= n:
continue
features[int(node), np.array(feats, dtype=int)] = 1
new_label = label[reorder_node_ids]
label = new_label

return A, label, features

def process(self):
A, label, features = self.load_twitch(self.name)
edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
features = np.array(features)
x = torch.from_numpy(features).to(torch.float)
y = torch.from_numpy(label).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate([data])

torch.save((data, slices), self.processed_paths[0])
Binary file added fig/model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit d721f51

Please sign in to comment.