In [1]:
import torch
import torch.nn.functional as F
import torch.utils.data

from s2cnn import SO3Convolution
from s2cnn import S2Convolution
from s2cnn import so3_integrate
from s2cnn import so3_near_identity_grid
from s2cnn import s2_near_identity_grid

import time
import random

random.seed()

In [2]:
def input0():
    return torch.zeros(1,60,60)

def inputrand():
    return torch.rand(1,60,60)

def visualize_tensor_layer(t): # expecting 2D tensor of any size
    assert len(t.shape) == 2
    for i in range(t.shape[0]):
        for j in range(t.shape[1]):
            print('{}'.format('1' if t[i][j].item() > 0. else '0'), end='')
        print('')

def add_light_fuzz(t, fuzz_factor=0.02): # expecting N * 60 * 60, add 2% fuzz
    assert fuzz_factor >= 0.
    assert fuzz_factor < 1.
    out = t.clone()
    for i in range(60):
        for j in range(60):
            if random.random() < fuzz_factor: # 5%ish
                out[:,i,j] = 1.
    return out
        
# could pass in all 0s, or add this feature to existing tensor
# applies to all layers
def make_training_feature1(t): # expecting N * 60 * 60
    for i in range(5, 35):
        for j in range(5, 35):
            x = i - 20
            y = j - 20
            r2 = x * x + y * y 
            if r2 > 64 and r2 < 144:
                t[:,i,j] = 1.
    return t

def make_training_feature2(t): # expecting N * 60 * 60
    for i in range(30, 60):
        for j in range(30, 60):
            x = i - 30
            y = 60 - j # 0,0 @ ctr btm
            if abs(x - y) <= 2:
                t[:,i,j] = 1.
    return t

def make_training_feature3(t): # expecting N * 60 * 60
    for i in range(30, 60):
        for j in range(0, 30):
            if abs(45 - i) <= 2 or abs(15 - j) <= 2:
                t[:,i,j] = 1.
    return t

def make_training_feature4(t):
    for i in range(0, 30):
        for j in range(30, 60):
            x = i
            y = 0 - j
            if abs(x - y) <= 2 or ((30 - x) - y) <= 2:
                t[:,i,j] = 1.
    return t

In [3]:
def reflect_vert(t): # expects 2d tensor
    out = t.clone()
    assert len(out.shape) == 2
    
    w = out.shape[0]
    h = out.shape[1]
    buffer = torch.zeros(h)
    for i in range(w // 2):
        buffer = out[i, :].clone() # i think i have to clone else broadcasting turns weird?
        out[i, :] = out[w - i - 1, :].clone()
        out[w - i - 1, :] = buffer
    return out

def reflect_horiz(t):
    return reflect_vert(t.transpose(0,1)).transpose(0,1)

In [4]:
# vis each of N kernels separately
# each out kernel is a 20^3 tensor in default config
# changed to 5 * 10^3 for this run
def visualize_s2conv_out(t):
    assert len(t.shape) == 4
    for kernel in range(t.shape[0]):
        print('KERNEL {}'.format(kernel))
        outblock = []
        for i in range(t.shape[1]): # row * column * depth = grid * depth = block
            outgrid = []
            for j in range(t.shape[2]): # row * column = grid
                outrow = ''
                for k in range(t.shape[3]): # row
                    outrow += '1' if t[kernel,i,j,k] > 0. else '0'
                outgrid.append(outrow)
            outblock.append(outgrid)
            
        visblock = []
        for i in range(len(outblock[0])):
            visblock.append([])
        for i in range(len(outblock)): # pull each grid from block
            grid = outblock[i]
            for j in range(len(grid)): # pull each row from grid
                visblock[j].append(grid[j])
        for i in range(len(visblock)):
            print(' '.join(visblock[i]))
                    
def show_conv_layer_for(t, model): # expects 1 * 60 * 60 tensor
    torch.unsqueeze(t, 0) # batch
    _, _, t_conv = model(t)
    visualize_s2conv_out(torch.squeeze(t_conv, 0))

In [5]:
class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, n_inst=1000, real_only=False, fuzz_factor=None): # make 1000 instances of my training instances plus some bllshit
        tf1 = make_training_feature1(input0()) # depends on global fns oops
        tf2 = make_training_feature2(input0())
        tf3 = make_training_feature3(input0())
        tf4 = make_training_feature4(input0())
        REAL_FEATURES = 4
        MAX_FEATURES = 5
        self.instances = []
        for _ in range(n_inst):
            r = random.randrange(REAL_FEATURES if real_only else MAX_FEATURES)
            new_instance = None
            if r == 0:
                new_instance = tf1 # circle in upper left
            elif r == 1:
                new_instance = tf2 # SW-NE diag line in lower right
            elif r == 2:
                new_instance = tf3 # cross upper right
            elif r == 3:
                new_instance = tf4 # diag cross lower right
            else:
                new_instance = inputrand() # bullshit
            
            assert not (new_instance is None)
            if (not (fuzz_factor is None)) and r < REAL_FEATURES:
                new_instance = add_light_fuzz(new_instance, fuzz_factor=fuzz_factor)
            self.instances.append((new_instance, r))
    
    def __len__(self):
        return len(self.instances)
    
    def __getitem__(self, idx):
        return self.instances[idx] # tensor, class idx

training_dataset = DummyDataset()
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=4, shuffle=True)
testing_dataset = DummyDataset(n_inst = 100, real_only=True, fuzz_factor=0.01)
testing_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=1, shuffle=False)

In [6]:
class S2TestModel(torch.nn.Module):
    def __init__(self):
        super(S2TestModel, self).__init__()
        self.conv1 = S2Convolution(nfeature_in=1, nfeature_out=10, b_in=30, b_out=5, grid=s2_near_identity_grid())
        self.fc1 = torch.nn.Linear(10, 20) # i am just fucking around with this architecture - will it learn?
        self.fc2 = torch.nn.Linear(20, 5)
    
    def forward(self, x):
        # print(x.shape)
        x = self.conv1(x)
        conv_out = x
        # print(x.shape)
        x = F.relu(x)
        x = so3_integrate(x)
        int_out = x
        # print(x.shape)
        # activation?
        x = self.fc1(x)
        # print(x.shape)
        x = F.relu(x)
        x = self.fc2(x)
        # print(x.shape)
        return x, int_out, conv_out

In [7]:
def tlog(s):
    print('{}: {}'.format(time.asctime(), s))

In [8]:
N_EPOCHS = 100
BATCH_SIZE = 4
FUZZ_FACTOR = 0.01

In [9]:
training_dataset = DummyDataset(fuzz_factor=FUZZ_FACTOR)
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
testing_dataset = DummyDataset(n_inst=100, real_only=True, fuzz_factor=FUZZ_FACTOR)
testing_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=1, shuffle=False)

model = S2TestModel()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [10]:
for epoch in range(N_EPOCHS):
    tlog('EPOCH {} of {}'.format(epoch + 1, N_EPOCHS))
    
    # train
    running_loss = 0.
    for i, (images, labels) in enumerate(training_loader):
        model.train()
        
        optimizer.zero_grad()
        guesses, _, _ = model(images) # discard intermediate outputs
        loss = loss_fn(guesses, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if (i + 1) % 50 == 0:
            tlog('iter {}/{}  loss {}  running loss {}'.format(i + 1, len(training_dataset) // BATCH_SIZE, loss.item(), running_loss))
    
    # validate
    total = 0
    correct = 0
    for i, (images, labels) in enumerate(testing_loader):
        model.eval()
        with torch.no_grad():
            guesses, _, _ = model(images)
            _, predictions = torch.max(guesses, 1)
            total += labels.size(0) # add batch size to total
            correct += (predictions == labels).long().sum().item()
    tlog('Test accuracy for epoch {}: {}'.format(epoch, float(correct) / float(total)))

Thu Apr 18 17:59:34 2019: EPOCH 1 of 100
load 2.pkl.gz... done
load 2.pkl.gz... done
load 4.pkl.gz... done
Thu Apr 18 17:59:34 2019: iter 50/250  loss 1.7012929916381836  running loss 72.44314575195312
Thu Apr 18 17:59:35 2019: iter 100/250  loss 1.2336326837539673  running loss 138.89939546585083
Thu Apr 18 17:59:35 2019: iter 150/250  loss 1.244809627532959  running loss 201.81333076953888
Thu Apr 18 17:59:36 2019: iter 200/250  loss 1.3569700717926025  running loss 263.418460637331
Thu Apr 18 17:59:36 2019: iter 250/250  loss 0.780707597732544  running loss 319.1817424297333
Thu Apr 18 17:59:36 2019: Test accuracy for epoch 0: 0.31
Thu Apr 18 17:59:36 2019: EPOCH 2 of 100
Thu Apr 18 17:59:37 2019: iter 50/250  loss 0.6877192854881287  running loss 54.09107583761215
Thu Apr 18 17:59:37 2019: iter 100/250  loss 0.6392704248428345  running loss 103.9368489086628
Thu Apr 18 17:59:38 2019: iter 150/250  loss 0.6217330098152161  running loss 155.21555760502815
Thu Apr 18 17:59:38 2019: it

Thu Apr 18 18:00:17 2019: iter 50/250  loss 0.04784655570983887  running loss 4.142983675003052
Thu Apr 18 18:00:18 2019: iter 100/250  loss 0.06721460819244385  running loss 8.751645803451538
Thu Apr 18 18:00:18 2019: iter 150/250  loss 0.4835399389266968  running loss 12.939526557922363
Thu Apr 18 18:00:19 2019: iter 200/250  loss 0.07257747650146484  running loss 17.086385369300842
Thu Apr 18 18:00:19 2019: iter 250/250  loss 0.04731476306915283  running loss 20.94782519340515
Thu Apr 18 18:00:20 2019: Test accuracy for epoch 14: 1.0
Thu Apr 18 18:00:20 2019: EPOCH 16 of 100
Thu Apr 18 18:00:20 2019: iter 50/250  loss 0.06874048709869385  running loss 4.195220112800598
Thu Apr 18 18:00:21 2019: iter 100/250  loss 0.14701485633850098  running loss 8.043732523918152
Thu Apr 18 18:00:21 2019: iter 150/250  loss 0.011994123458862305  running loss 11.223838329315186
Thu Apr 18 18:00:22 2019: iter 200/250  loss 0.13698136806488037  running loss 14.00735354423523
Thu Apr 18 18:00:22 2019: 

Thu Apr 18 18:01:01 2019: iter 50/250  loss 0.013907551765441895  running loss 0.5769209861755371
Thu Apr 18 18:01:02 2019: iter 100/250  loss 0.010252237319946289  running loss 1.1634210348129272
Thu Apr 18 18:01:02 2019: iter 150/250  loss 0.006017804145812988  running loss 1.634366512298584
Thu Apr 18 18:01:03 2019: iter 200/250  loss 0.009217143058776855  running loss 2.263617515563965
Thu Apr 18 18:01:03 2019: iter 250/250  loss 0.02204275131225586  running loss 2.7139729261398315
Thu Apr 18 18:01:04 2019: Test accuracy for epoch 28: 1.0
Thu Apr 18 18:01:04 2019: EPOCH 30 of 100
Thu Apr 18 18:01:04 2019: iter 50/250  loss 0.002591252326965332  running loss 0.5044167041778564
Thu Apr 18 18:01:05 2019: iter 100/250  loss 0.019701480865478516  running loss 1.3006447553634644
Thu Apr 18 18:01:05 2019: iter 150/250  loss 0.004873394966125488  running loss 1.7449367046356201
Thu Apr 18 18:01:06 2019: iter 200/250  loss 0.005588054656982422  running loss 2.292712450027466
Thu Apr 18 18:0

Thu Apr 18 18:01:45 2019: Test accuracy for epoch 41: 1.0
Thu Apr 18 18:01:45 2019: EPOCH 43 of 100
Thu Apr 18 18:01:45 2019: iter 50/250  loss 0.0043032169342041016  running loss 0.2075110673904419
Thu Apr 18 18:01:46 2019: iter 100/250  loss 0.0008511543273925781  running loss 0.3937809467315674
Thu Apr 18 18:01:46 2019: iter 150/250  loss 0.0012335777282714844  running loss 0.5967915058135986
Thu Apr 18 18:01:47 2019: iter 200/250  loss 0.022738337516784668  running loss 0.8545222282409668
Thu Apr 18 18:01:48 2019: iter 250/250  loss 0.00590062141418457  running loss 1.250901699066162
Thu Apr 18 18:01:48 2019: Test accuracy for epoch 42: 1.0
Thu Apr 18 18:01:48 2019: EPOCH 44 of 100
Thu Apr 18 18:01:48 2019: iter 50/250  loss 0.004730582237243652  running loss 0.22783362865447998
Thu Apr 18 18:01:49 2019: iter 100/250  loss 0.012341856956481934  running loss 0.45401835441589355
Thu Apr 18 18:01:49 2019: iter 150/250  loss 0.016266942024230957  running loss 0.6601917743682861
Thu Apr

Thu Apr 18 18:02:26 2019: iter 250/250  loss 0.0008101463317871094  running loss 0.8908646106719971
Thu Apr 18 18:02:27 2019: Test accuracy for epoch 55: 1.0
Thu Apr 18 18:02:27 2019: EPOCH 57 of 100
Thu Apr 18 18:02:27 2019: iter 50/250  loss 0.0011904239654541016  running loss 0.12647223472595215
Thu Apr 18 18:02:28 2019: iter 100/250  loss 0.0005021095275878906  running loss 0.23830246925354004
Thu Apr 18 18:02:28 2019: iter 150/250  loss 0.0014193058013916016  running loss 0.5791858434677124
Thu Apr 18 18:02:29 2019: iter 200/250  loss 0.0011472702026367188  running loss 0.7084864377975464
Thu Apr 18 18:02:30 2019: iter 250/250  loss 0.001474618911743164  running loss 0.8450053930282593
Thu Apr 18 18:02:30 2019: Test accuracy for epoch 56: 1.0
Thu Apr 18 18:02:30 2019: EPOCH 58 of 100
Thu Apr 18 18:02:31 2019: iter 50/250  loss 0.001054525375366211  running loss 0.11716365814208984
Thu Apr 18 18:02:31 2019: iter 100/250  loss 0.00411534309387207  running loss 0.2666299343109131
Thu

Thu Apr 18 18:03:09 2019: iter 200/250  loss 0.0008141994476318359  running loss 0.35329437255859375
Thu Apr 18 18:03:10 2019: iter 250/250  loss 0.0008668899536132812  running loss 0.6373945474624634
Thu Apr 18 18:03:10 2019: Test accuracy for epoch 69: 1.0
Thu Apr 18 18:03:10 2019: EPOCH 71 of 100
Thu Apr 18 18:03:11 2019: iter 50/250  loss 0.0003750324249267578  running loss 0.29145145416259766
Thu Apr 18 18:03:11 2019: iter 100/250  loss 0.001971006393432617  running loss 0.3854668140411377
Thu Apr 18 18:03:12 2019: iter 150/250  loss 0.0021085739135742188  running loss 0.49733424186706543
Thu Apr 18 18:03:13 2019: iter 200/250  loss 0.0006787776947021484  running loss 0.5790929794311523
Thu Apr 18 18:03:13 2019: iter 250/250  loss 0.0015070438385009766  running loss 0.6510272026062012
Thu Apr 18 18:03:13 2019: Test accuracy for epoch 70: 1.0
Thu Apr 18 18:03:13 2019: EPOCH 72 of 100
Thu Apr 18 18:03:14 2019: iter 50/250  loss 0.002580881118774414  running loss 0.08542013168334961


Thu Apr 18 18:03:54 2019: iter 150/250  loss 0.00220489501953125  running loss 0.3856689929962158
Thu Apr 18 18:03:55 2019: iter 200/250  loss 0.002689361572265625  running loss 0.4281938076019287
Thu Apr 18 18:03:55 2019: iter 250/250  loss 0.0003113746643066406  running loss 0.5001001358032227
Thu Apr 18 18:03:55 2019: Test accuracy for epoch 83: 1.0
Thu Apr 18 18:03:55 2019: EPOCH 85 of 100
Thu Apr 18 18:03:56 2019: iter 50/250  loss 0.0008363723754882812  running loss 0.0663151741027832
Thu Apr 18 18:03:57 2019: iter 100/250  loss 0.0037560462951660156  running loss 0.13071632385253906
Thu Apr 18 18:03:57 2019: iter 150/250  loss 0.00048732757568359375  running loss 0.19929242134094238
Thu Apr 18 18:03:58 2019: iter 200/250  loss 0.00016021728515625  running loss 0.2554435729980469
Thu Apr 18 18:03:58 2019: iter 250/250  loss 0.0005936622619628906  running loss 0.47195863723754883
Thu Apr 18 18:03:59 2019: Test accuracy for epoch 84: 1.0
Thu Apr 18 18:03:59 2019: EPOCH 86 of 100
Th

Thu Apr 18 18:04:38 2019: iter 100/250  loss 0.005831718444824219  running loss 0.29160046577453613
Thu Apr 18 18:04:38 2019: iter 150/250  loss 0.0035860538482666016  running loss 0.35128211975097656
Thu Apr 18 18:04:39 2019: iter 200/250  loss 0.00045609474182128906  running loss 0.3998434543609619
Thu Apr 18 18:04:39 2019: iter 250/250  loss 0.0009102821350097656  running loss 0.44585561752319336
Thu Apr 18 18:04:40 2019: Test accuracy for epoch 97: 1.0
Thu Apr 18 18:04:40 2019: EPOCH 99 of 100
Thu Apr 18 18:04:40 2019: iter 50/250  loss 0.0001316070556640625  running loss 0.04654240608215332
Thu Apr 18 18:04:41 2019: iter 100/250  loss 0.0013399124145507812  running loss 0.0926353931427002
Thu Apr 18 18:04:41 2019: iter 150/250  loss 0.00042819976806640625  running loss 0.13300108909606934
Thu Apr 18 18:04:42 2019: iter 200/250  loss 0.0003333091735839844  running loss 0.19655680656433105
Thu Apr 18 18:04:42 2019: iter 250/250  loss 0.0012197494506835938  running loss 0.43508315086

In [11]:
tf1 = make_training_feature1(input0()) # depends on global fns oops
tf2 = make_training_feature2(input0())
tf3 = make_training_feature3(input0())
tf4 = make_training_feature4(input0())
tfs = [tf1, tf2, tf3, tf4]

In [12]:
for i, tf_raw in enumerate(tfs):
    tf = torch.squeeze(tf_raw)
    tfv = reflect_vert(tf)
    tfh = reflect_horiz(tf)
    tfvh = reflect_horiz(tfv)
    for t_in in [tf, tfv, tfh, tfvh]:
        t_in = torch.squeeze(add_light_fuzz(torch.unsqueeze(t_in, 0), fuzz_factor=FUZZ_FACTOR))
        guess, _, _ = model(torch.unsqueeze(torch.unsqueeze(t_in, 0), 0))
        _, pred = torch.max(guess, 1)
        print('got {} for {}'.format(pred.item(), i))

got 0 for 0
got 0 for 0
got 0 for 0
got 0 for 0
got 1 for 1
got 1 for 1
got 1 for 1
got 1 for 1
got 2 for 2
got 2 for 2
got 2 for 2
got 2 for 2
got 3 for 3
got 3 for 3
got 3 for 3
got 3 for 3


In [None]:
t_out, t_int, t_conv = model(torch.unsqueeze(input0(), 0))
print('output shape {}'.format(t_out.shape))
print('post-integration shape {}'.format(t_int.shape))
print('post-conv shape {}'.format(t_conv.shape))
print(torch.squeeze(t_out, 0))
print(torch.squeeze(t_int, 0))
visualize_s2conv_out(torch.squeeze(t_conv, 0))

In [None]:
show_conv_layer_for(tf1, model)