add bucketing
tsepaole committed Jan 10, 2024
1 parent 137808b commit 815cf5a
Showing 5 changed files with 101 additions and 57 deletions.
11 changes: 1 addition & 10 deletions configs/GPS/cocosuperpixels-GPS.yaml
@@ -36,7 +36,7 @@ model:
   type: GPSModel
   loss_fun: weighted_cross_entropy
 gt:
-  layer_type: CustomGatedGCN+Mamba_Cluster #Transformer #Performer
+  layer_type: CustomGatedGCN+Mamba_Bucket #Transformer #Performer
   layers: 4
   n_heads: 8
   dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
@@ -62,15 +62,6 @@ optim:
   max_epoch: 300
   scheduler: cosine_with_warmup
   num_warmup_epochs: 10
-# optim:
-# clip_grad_norm: True
-# optimizer: adamW
-# weight_decay: 0.05
-# base_lr: 0.001
-# max_epoch: 300
-# scheduler: cosine_with_warmup
-# num_warmup_epochs: 10
-
 #optim:
 # optimizer: adamW
 # weight_decay: 0.0
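The `layer_type` string selects both halves of each GPS layer. A minimal sketch of the usual GraphGPS convention (assumed here; the split itself is not part of this diff): the part before `+` names the local message-passing model, and the part after names the global model that `gps_layer.py` dispatches on.

    # Sketch of the assumed GraphGPS convention for `gt.layer_type`:
    # "<local message-passing model>+<global model>".
    layer_type = "CustomGatedGCN+Mamba_Bucket"
    local_gnn_type, global_model_type = layer_type.split("+")
    assert local_gnn_type == "CustomGatedGCN"   # handled by the local GNN branch
    assert global_model_type == "Mamba_Bucket"  # dispatched on in gps_layer.py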
21 changes: 3 additions & 18 deletions configs/GPS/peptides-func-GPS.yaml
@@ -26,13 +26,6 @@ posenc_LapPE:
   dim_pe: 16
   layers: 2
   raw_norm_type: none
-# posenc_RWSE:
-# enable: True
-# kernel:
-# times_func: range(1,17)
-# model: Linear
-# dim_pe: 20
-# raw_norm_type: BatchNorm
 train:
   mode: custom
   batch_size: 128
@@ -43,7 +36,7 @@ model:
   loss_fun: cross_entropy
   graph_pooling: mean
 gt:
-  layer_type: CustomGatedGCN+Mamba_Hybrid_Degree #Mamba_Hybrid_Degree
+  layer_type: CustomGatedGCN+Mamba_Bucket
   n_heads: 4
   dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
   dropout: 0.0
@@ -58,19 +51,11 @@ gnn:
   batchnorm: True
   act: relu
   dropout: 0.0
-# optim:
-# clip_grad_norm: True
-# optimizer: adamW
-# weight_decay: 0.0
-# base_lr: 0.0003
-# max_epoch: 200
-# scheduler: cosine_with_warmup
-# num_warmup_epochs: 10
 optim:
   clip_grad_norm: True
   optimizer: adamW
-  weight_decay: 0.01
-  base_lr: 0.001
+  weight_decay: 0.0
+  base_lr: 0.0003
   max_epoch: 200
   scheduler: cosine_with_warmup
   num_warmup_epochs: 10
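For reference, a minimal sketch of what a `cosine_with_warmup` schedule typically computes (an assumed formula, not the repo's exact code): a linear ramp for `num_warmup_epochs`, then a cosine decay from `base_lr` to zero at `max_epoch`.

    import math

    def cosine_with_warmup_lr(epoch, base_lr=0.0003, warmup=10, total=200):
        # Sketch of a typical cosine-with-warmup schedule (assumed, not the
        # repo's exact implementation): linear ramp, then cosine decay to zero.
        if epoch < warmup:
            return base_lr * epoch / max(1, warmup)
        progress = (epoch - warmup) / max(1, total - warmup)
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))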
14 changes: 3 additions & 11 deletions configs/GPS/peptides-struct-GPS.yaml
@@ -37,7 +37,7 @@ model:
   loss_fun: l1
   graph_pooling: mean
 gt:
-  layer_type: CustomGatedGCN+Mamba_Cluster #Mamba_Hybrid_Degree #Mamba
+  layer_type: CustomGatedGCN+Mamba_Cluster_Bucket #Mamba_Hybrid_Degree #Mamba
   layers: 4
   n_heads: 4
   dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
@@ -53,19 +53,11 @@ gnn:
   batchnorm: True
   act: relu
   dropout: 0.0
-#optim:
-# clip_grad_norm: True
-# optimizer: adamW
-# weight_decay: 0.0
-# base_lr: 0.0003
-# max_epoch: 200
-# scheduler: cosine_with_warmup
-# num_warmup_epochs: 10
 optim:
   clip_grad_norm: True
   optimizer: adamW
-  weight_decay: 0.01
-  base_lr: 0.001
+  weight_decay: 0.0
+  base_lr: 0.0003
   max_epoch: 200
   scheduler: cosine_with_warmup
   num_warmup_epochs: 10
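peptides-struct is a regression benchmark, so `loss_fun: l1` denotes a mean-absolute-error objective. A minimal usage sketch (the 11-target shape matches the Peptides-struct task; the tensors here are placeholders):

    import torch
    import torch.nn.functional as F

    # `loss_fun: l1` corresponds to mean absolute error on graph-level targets.
    pred = torch.randn(128, 11)     # batch_size graphs x 11 regression targets
    target = torch.randn(128, 11)
    loss = F.l1_loss(pred, target)  # mean |pred - target|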
25 changes: 8 additions & 17 deletions configs/GPS/vocsuperpixels-GPS.yaml
@@ -1,7 +1,6 @@
 out_dir: results
 metric_best: f1
 wandb:
-  use: True
   project: PascalVOC-SP
   entity: tf-map
 dataset:
@@ -36,7 +35,7 @@ model:
   type: GPSModel
   loss_fun: weighted_cross_entropy
 gt:
-  layer_type: CustomGatedGCN+Mamba_Cluster
+  layer_type: CustomGatedGCN+Mamba_Bucket
   layers: 4
   n_heads: 8
   dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
@@ -54,22 +53,14 @@ gnn:
   dropout: 0.0
   agg: mean
   normalize_adj: False
-# optim:
-# clip_grad_norm: True
-# optimizer: adamW
-# weight_decay: 0.01
-# base_lr: 0.001
-# max_epoch: 300
-# scheduler: cosine_with_warmup
-# num_warmup_epochs: 10
 optim:
-  clip_grad_norm: True
-  optimizer: adamW
-  weight_decay: 0.0
-  base_lr: 0.0005
-  max_epoch: 300
-  scheduler: cosine_with_warmup
-  num_warmup_epochs: 10
+  clip_grad_norm: True
+  optimizer: adamW
+  weight_decay: 0.0
+  base_lr: 0.0005
+  max_epoch: 300
+  scheduler: cosine_with_warmup
+  num_warmup_epochs: 10
 #optim:
 # optimizer: adamW
 # weight_decay: 0.0
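Both superpixel configs keep `loss_fun: weighted_cross_entropy`, which matters for these node-classification datasets because class frequencies are heavily skewed. A hedged sketch of the usual weighting (an assumed formula; the repo's exact implementation may differ): each class is weighted inversely to its frequency.

    import torch
    import torch.nn.functional as F

    def weighted_cross_entropy(logits, labels):
        # Sketch (assumed formula): weight each class inversely to its
        # frequency so rare classes are not drowned out by common ones.
        num_classes = logits.size(1)
        counts = torch.bincount(labels, minlength=num_classes).float()
        weight = counts.sum() / (num_classes * counts.clamp(min=1.0))
        return F.cross_entropy(logits, labels, weight=weight)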
87 changes: 86 additions & 1 deletion graphgps/layer/gps_layer.py
@@ -128,6 +128,7 @@ def __init__(self, dim_h,
         self.layer_norm = layer_norm
         self.batch_norm = batch_norm
         self.equivstable_pe = equivstable_pe
+        self.NUM_BUCKETS = 2

         # Local message-passing model.
         if local_gnn_type == 'None':
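`NUM_BUCKETS` drives the new `*_Bucket` branches below: a permutation of node indices is partitioned by index value modulo `NUM_BUCKETS`, so bucket i holds the node ids congruent to i, in randomized order. A standalone sketch of the split (the names here are placeholders):

    import torch

    NUM_BUCKETS = 2
    perm = torch.randperm(10)  # stand-in for permute_within_batch(batch.batch)
    # Bucket i keeps the node ids equal to i modulo NUM_BUCKETS, in the
    # random order induced by `perm`; the buckets partition all nodes.
    buckets = [perm[perm % NUM_BUCKETS == i] for i in range(NUM_BUCKETS)]
    assert sum(b.numel() for b in buckets) == perm.numel()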
@@ -338,6 +339,7 @@ def forward(self, batch):
                         h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                         mamba_arr.append(h_attn)
                     h_attn = sum(mamba_arr) / 5
+
             elif 'Mamba_Hybrid_Degree' in self.global_model_type:
                 if batch.split == 'train':
                     h_ind_perm = permute_within_batch(batch.batch)
@@ -378,6 +380,7 @@ def forward(self, batch):
                         #h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                         mamba_arr.append(h_attn)
                     h_attn = sum(mamba_arr) / 5
+
             elif self.global_model_type == 'Mamba_Eigen':
                 deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
                 centrality = batch.EigCentrality
@@ -419,7 +422,7 @@ def forward(self, batch):
                 unique_cluster_n = len(torch.unique(batch.LouvainCluster))
                 permuted_louvain = torch.zeros(batch.LouvainCluster.shape).long().to(batch.LouvainCluster.device)
                 random_permute = torch.randperm(unique_cluster_n + 1).long().to(batch.LouvainCluster.device)
-                for i in range(len(torch.unique(batch.LouvainCluster))):
+                for i in range(unique_cluster_n):
                     indices = torch.nonzero(batch.LouvainCluster == i).squeeze()
                     permuted_louvain[indices] = random_permute[i]
                 #h_ind_perm_1 = lexsort([deg[h_ind_perm], permuted_louvain[h_ind_perm], batch.batch[h_ind_perm]])
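In the `Mamba_Cluster` hunk above, the loop assigns `permuted_louvain[node] = random_permute[cluster_of_node]`, relabeling clusters with random ids so their relative order changes between passes. Assuming `batch.LouvainCluster` holds contiguous ids starting at 0, the loop is equivalent to a single gather:

    # Equivalent vectorized form of the relabeling loop (sketch; assumes
    # batch.LouvainCluster holds contiguous cluster ids 0..unique_cluster_n-1):
    permuted_louvain = random_permute[batch.LouvainCluster]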
@@ -453,11 +456,93 @@ def forward(self, batch):
                         h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                         mamba_arr.append(h_attn)
                     h_attn = sum(mamba_arr) / 5
+
             elif self.global_model_type == 'Mamba_Augment':
                 aug_idx, aug_mask = augment_seq(batch.edge_index, batch.batch, 3)
                 h_dense, mask = to_dense_batch(h[aug_idx], batch.batch[aug_idx])
                 aug_idx_reverse = torch.nonzero(aug_mask).squeeze()
                 h_attn = self.self_attn(h_dense)[mask][aug_idx_reverse]
+
+            elif self.global_model_type == 'Mamba_Hybrid_Degree_Bucket':
+                if batch.split == 'train':
+                    h_ind_perm = permute_within_batch(batch.batch)
+                    deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
+                    indices_arr, emb_arr = [], []
+                    for i in range(self.NUM_BUCKETS):
+                        ind_i = h_ind_perm[h_ind_perm % self.NUM_BUCKETS == i]
+                        h_ind_perm_sort = lexsort([deg[ind_i], batch.batch[ind_i]])
+                        h_ind_perm_i = ind_i[h_ind_perm_sort]
+                        h_dense, mask = to_dense_batch(h[h_ind_perm_i], batch.batch[h_ind_perm_i])
+                        h_dense = self.self_attn(h_dense)[mask]
+                        indices_arr.append(h_ind_perm_i)
+                        emb_arr.append(h_dense)
+                    h_ind_perm_reverse = torch.argsort(torch.cat(indices_arr))
+                    h_attn = torch.cat(emb_arr)[h_ind_perm_reverse]
+                else:
+                    mamba_arr = []
+                    for i in range(5):
+                        h_ind_perm = permute_within_batch(batch.batch)
+                        deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
+                        indices_arr, emb_arr = [], []
+                        for i in range(self.NUM_BUCKETS):
+                            ind_i = h_ind_perm[h_ind_perm % self.NUM_BUCKETS == i]
+                            h_ind_perm_sort = lexsort([deg[ind_i], batch.batch[ind_i]])
+                            h_ind_perm_i = ind_i[h_ind_perm_sort]
+                            h_dense, mask = to_dense_batch(h[h_ind_perm_i], batch.batch[h_ind_perm_i])
+                            h_dense = self.self_attn(h_dense)[mask]
+                            indices_arr.append(h_ind_perm_i)
+                            emb_arr.append(h_dense)
+                        h_ind_perm_reverse = torch.argsort(torch.cat(indices_arr))
+                        h_attn = torch.cat(emb_arr)[h_ind_perm_reverse]
+                        mamba_arr.append(h_attn)
+                    h_attn = sum(mamba_arr) / 5
+
+            elif self.global_model_type == 'Mamba_Cluster_Bucket':
+                if batch.split == 'train':
+                    # print(h.device, batch.batch.device)
+                    # batch.LouvainCluster = batch.LouvainCluster.to('cpu')
+                    unique_cluster_n = len(torch.unique(batch.LouvainCluster))
+                    h_ind_perm = permute_within_batch(batch.LouvainCluster)  # permute within batch + clusters
+                    # h_ind_perm = h_ind_perm.to(h.device)
+                    # batch.LouvainCluster = batch.LouvainCluster.to(h.device)
+                    deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
+                    h_ind_perm_sort = lexsort([deg[h_ind_perm], batch.LouvainCluster[h_ind_perm]])
+                    sorted_indices = h_ind_perm[h_ind_perm_sort]
+
+                    indices_arr, emb_arr = [], []
+                    for i in range(unique_cluster_n):
+                        ind_i = sorted_indices[batch.LouvainCluster == i]
+                        h_dense, mask = to_dense_batch(h[ind_i], batch.batch[ind_i])
+                        h_dense = self.self_attn(h_dense)[mask]
+                        indices_arr.append(ind_i)
+                        emb_arr.append(h_dense)
+
+                    h_ind_perm_reverse = torch.argsort(torch.cat(indices_arr))
+                    h_attn = torch.cat(emb_arr)[h_ind_perm_reverse]
+                else:
+                    mamba_arr = []
+                    for i in range(5):
+                        # batch.LouvainCluster = batch.LouvainCluster.to('cpu')
+                        unique_cluster_n = len(torch.unique(batch.LouvainCluster))
+                        h_ind_perm = permute_within_batch(batch.LouvainCluster)  # permute within batch + clusters
+                        # h_ind_perm = h_ind_perm.to(h.device)
+                        # batch.LouvainCluster = batch.LouvainCluster.to(h.device)
+                        deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
+                        h_ind_perm_sort = lexsort([deg[h_ind_perm], batch.LouvainCluster[h_ind_perm]])
+                        sorted_indices = h_ind_perm[h_ind_perm_sort]
+
+                        indices_arr, emb_arr = [], []
+                        for i in range(unique_cluster_n):
+                            ind_i = sorted_indices[batch.LouvainCluster == i]
+                            h_dense, mask = to_dense_batch(h[ind_i], batch.batch[ind_i])
+                            h_dense = self.self_attn(h_dense)[mask]
+                            indices_arr.append(ind_i)
+                            emb_arr.append(h_dense)
+
+                        h_ind_perm_reverse = torch.argsort(torch.cat(indices_arr))
+                        h_attn = torch.cat(emb_arr)[h_ind_perm_reverse]
+                        mamba_arr.append(h_attn)
+                    h_attn = sum(mamba_arr) / 5
             else:
                 raise RuntimeError(f"Unexpected {self.global_model_type}")
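Both bucket branches rest on the same two tricks. First, `lexsort([deg[...], <group>[...]])` orders each slice by its grouping key first (graph id or Louvain cluster) and node degree second, assuming the repo's `lexsort` follows numpy's convention that the last key is primary. Second, after processing buckets independently, `torch.argsort(torch.cat(indices_arr))` recovers the inverse permutation that restores original node order. A self-contained sketch of that inverse-permutation identity:

    import torch

    n = 10
    order = torch.randperm(n)            # node ids in bucket-processing order
    processed = torch.arange(n)[order]   # anything computed in that order
    inverse = torch.argsort(order)       # inverse permutation of `order`
    # Gathering with the inverse restores the original node order.
    assert torch.equal(processed[inverse], torch.arange(n))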
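`permute_within_batch` is defined elsewhere in graphgps and is not part of this commit; judging only from its call sites above, it plausibly returns node indices shuffled uniformly at random within each group id (a graph in the batch, or a Louvain cluster), never across groups. A hypothetical re-implementation inferred from usage:

    import torch

    def permute_within_batch_sketch(group: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for graphgps' permute_within_batch, inferred
        # from usage: shuffle node indices within each group, never across.
        out = []
        for g in torch.unique(group):
            members = torch.nonzero(group == g).squeeze(1)
            out.append(members[torch.randperm(members.numel())])
        return torch.cat(out)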