In [9]:
from src.nets.conv import GridEmbedding
import torch
from copy import deepcopy

In [39]:
# Example usage:
device = torch.device('cuda')
bs, L, n_cat, C = 4, 8, 10, 16
dummy_input = torch.randint(0, n_cat, (bs, L, L)).to(device)
t = torch.rand(size=(bs,)).to(device)
model = GridEmbedding(n_cat, C).to(device)
embedded_output = model(dummy_input)
print("Output shape:", embedded_output.shape)

Output shape: torch.Size([4, 16, 8, 8])


In [40]:
from src.nets.conv import PottsLocEquivConvNet

In [41]:
potts_net = PottsLocEquivConvNet(n_cat=n_cat,kernel_sizes=[3,5,7], num_channels=4).to(device)

In [42]:
output = potts_net(dummy_input,t)

Check local equivariance.

In [43]:
idx = 3
idy = 4
old_cat = dummy_input[:,idx,idy]
cp_dummy_input = deepcopy(dummy_input)
new_cat = torch.randint(0,n_cat,size=old_cat.shape)
cp_dummy_input[:,idx,idy] = new_cat
cp_output = potts_net(cp_dummy_input,t)

print("Flow from x to neighbors: ", output[torch.arange(len(new_cat)),new_cat,idx,idy])
print("Flow neighbors to x: ", cp_output[torch.arange(len(new_cat)),old_cat,idx,idy])

Flow from x to neighbors:  tensor([-0.6198,  1.8594,  0.0296, -0.1962], device='cuda:0',
       grad_fn=<IndexBackward0>)
Flow neighbors to x:  tensor([ 0.6198, -1.8594, -0.0296,  0.1962], device='cuda:0',
       grad_fn=<IndexBackward0>)


In [44]:
from src.nets.conv import ConvPottsRateMatrix

In [45]:
potts_matrix = ConvPottsRateMatrix(n_cat=n_cat,kernel_sizes=[3,5,7], num_channels=4).to(device)

In [46]:
out_rates, stay_rate = potts_matrix.get_out_rates(dummy_input, t)

In [47]:
in_rates , stay_rates = potts_matrix.get_in_rates(dummy_input,t)

In [48]:
neighbor_lh_ratios = 2.0*torch.randn(size=(dummy_input.shape[0],n_cat,dummy_input.shape[1], dummy_input.shape[2])).to(device)
jarz_corrector = potts_matrix.get_jarzinsky_corrector(dummy_input, t, neighbor_lh_ratios)

In [49]:
next_step = potts_matrix.sample_next_step(dummy_input, t, 0.1*t)

Show correct implementation of in-rates and out-rates:

In [61]:
idx = 2
idy = 3
old_cat = dummy_input[:,idx,idy]
cp_dummy_input = deepcopy(dummy_input)
new_cat = torch.randint(0,n_cat,size=old_cat.shape)
cp_dummy_input[:,idx,idy] = new_cat
cp_output = potts_net(cp_dummy_input,t)

In [62]:
out_rates,_=potts_matrix.get_in_rates(dummy_input, t)

In [63]:
in_rates,_=potts_matrix.get_out_rates(cp_dummy_input, t)

In [65]:
print("Outflow from x to neighbors: ", out_rates[torch.arange(len(new_cat)),new_cat,idx,idy])
print("Inflow from neighbors to x: ", in_rates[torch.arange(len(new_cat)),old_cat,idx,idy])

Outflow from x to neighbors:  tensor([0.0000, 0.7520, 0.0000, 0.0000], device='cuda:0',
       grad_fn=<IndexBackward0>)
Inflow from neighbors to x:  tensor([0.0000, 0.7520, 0.0000, 0.0000], device='cuda:0',
       grad_fn=<IndexBackward0>)
