In [1]:
import os, shutil, itertools, yaml, kornia, torchvision, sys, copy, math
from functools import partial
import dill as pickle
from importlib import reload
osp = os.path
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import seaborn as sns
from PIL import Image
import torch
nn = torch.nn
F = nn.functional
import monai.transforms as mtr
from time import time

In [2]:
from inrnet import args as args_module
from inrnet import inn, experiments, optim, util, losses, models
from inrnet.data import dataloader
from inrnet.experiments import depth
import inrnet.inn.functional as inrF
%matplotlib inline
# Notebook-wide matplotlib defaults: small figures rendered at high DPI.
plt.rcParams.update({
    "figure.figsize": (4.0, 3.0),
    "figure.dpi": 200,
})

# Intensity rescalers built from MONAI transforms.
# rescale_clip: maps the 1st..99th percentile window onto 0..255 and clips values outside it.
rescale_clip = mtr.ScaleIntensityRangePercentiles(
    lower=1, upper=99,
    b_min=0, b_max=255,
    clip=True, dtype=np.uint8,
)
# rescale_noclip: maps the full 0th..100th percentile range onto 0..255 without clipping.
rescale_noclip = mtr.ScaleIntensityRangePercentiles(
    lower=0, upper=100,
    b_min=0, b_max=255,
    clip=False, dtype=np.uint8,
)
# rescale_float: ScaleIntensity with its default arguments.
rescale_float = mtr.ScaleIntensity()

# Scratch directory under the user's home.
TMP_DIR = osp.expanduser("~/code/diffcoord/temp")

# Shell command (run from a terminal, not Python): copy the unetr_00 results
# directory from the remote host into ./Downloads, leaving files that already
# exist locally untouched (--ignore-existing).
rsync -av --ignore-existing \
clintonw@peppercorn.csail.mit.edu:/data/vision/polina/users/clintonw/code/placenta/results/unetr_00 Downloads/

## interactive

In [None]:
# Display the object stored in the inet results file.
# NOTE(review): despite the .txt suffix this is read with torch.load, which
# unpickles its input — only point it at files you trust.
torch.load("/data/vision/polina/users/clintonw/code/diffcoord/results/inet/results.txt")

In [None]:
# Terminal command kept here as a note — it is not valid Python and will raise a
# SyntaxError if this cell is executed. Runs train.py with only GPU 3 visible;
# presumably -j is the job name and -c the config name (confirm in train.py).
CUDA_VISIBLE_DEVICES=3 python train.py -j=inet2 -c=inet

In [None]:
# Shape probe: output size of an unpadded 3x3, stride-2 convolution (3 -> 64 channels)
# applied to a random 1x3x8x8 input.
nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))(torch.randn(1,3,8,8)).shape

In [None]:
# Shape probe: output size of an unpadded 3x3, stride-2 max pool applied to a
# random 1x3x8x8 input (compare with the convolution probe using the same kernel/stride).
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))(torch.randn(1,3,8,8)).shape

In [None]:
def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Walks ``model.parameters()`` and adds up ``numel()`` for every parameter
    whose ``requires_grad`` flag is set; frozen parameters are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Show the number of trainable parameters of `model`.
# NOTE: `model` is only assigned in a later cell (the efficientnet_b0 load), so
# this cell fails on a fresh top-to-bottom run unless that cell is executed first.
count_parameters(model)

In [33]:
# EfficientNet-B0 from torchvision with its pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version before changing.
model = torchvision.models.efficientnet_b0(pretrained=True)

In [78]:
# Release cached GPU memory held by PyTorch before allocating new test tensors.
torch.cuda.empty_cache()
# Coordinate grid from the project helper — presumably one row per point of a
# 64x64 grid; TODO confirm the layout and device in util.meshgrid_coords.
coords = util.meshgrid_coords(64,64)
#spacing = torch.tensor((4/127,4/127), device="cuda")
# Random 6-channel values, one row per coordinate, on the same device as `coords`.
values = torch.randn(coords.size(0),6, device=coords.device)

In [None]:
# Mock-up of the "error vs. number of sample points" figure.
# NOTE: both curves are synthetic — a linear ramp plus unseeded Gaussian noise —
# not measured results, and they change on every run.
ax = sns.lineplot(x=torch.linspace(128,256,10), y=torch.linspace(.5,.8,10)+torch.randn(10)/8, label="qmc");
sns.lineplot(x=torch.linspace(128,256,10), y=torch.linspace(.1,.8,10)+torch.randn(10)/8, label="grid", ax=ax);
ax.set_xlabel("Number of sample points");
ax.set_ylabel("ImageNet top-1 error");

## Figs

In [None]:
# Scatter plot of 256 points from the project's 2-D quasirandom sequence generator.
xy = util.generate_quasirandom_sequence(d=2, n=256)
# NOTE: this changes the global default figure size, so it also affects every
# figure created by later cells until rcParams is set again.
plt.rcParams["figure.figsize"] = (2,1)
fig,ax = plt.subplots()
ax.scatter(xy[:,0], xy[:,1], s=2, c='k')
# plot() with no data draws nothing; the trailing ';' only suppresses the cell's text output.
ax.plot();

In [None]:
# Pull one 2-D weight slice (index [3, 1]) out of the model's first feature layer —
# presumably a single 3x3 filter of the stem convolution; TODO confirm.
kernel = model.features[0][0].weight[3,1].detach()
# NOTE: global figure-size change; later figures inherit this tiny size too.
plt.rcParams["figure.figsize"] = (.5,.5)
fig,ax = plt.subplots()
# Print the value range to check it fits inside the fixed colour limits below.
print(kernel.min().item(), kernel.max().item())
# Fixed, symmetric colour limits so this image can be compared with other kernel plots.
ax.imshow(kernel, vmin=-2.9, vmax=2.9);
plt.axis('off');

In [100]:
from scipy.interpolate import RectBivariateSpline as Spline2D

# Fit a degree-2 (kx=ky=2) bivariate spline with zero smoothing (s=0) through the
# h x w samples stored in `kernel`, then keep its knots and coefficients as tensors.
K = [2.5, 2.5]          # extent of the spline's bounding box along x and y
h, w = 3, 3             # number of kernel samples along x and y
bbox = (-K[0]/2, K[0]/2, -K[1]/2, K[1]/2)

# Sample positions: the bounding box shrunk by a factor (n-1)/n along each axis.
x = np.linspace(bbox[0]/h*(h-1), bbox[1]/h*(h-1), h)
y = np.linspace(bbox[2]/w*(w-1), bbox[3]/w*(w-1), w)

bs = Spline2D(x, y, kernel, bbox=bbox, kx=2, ky=2, s=0)

# bs.tck holds (x knots, y knots, flat coefficient vector); convert each to float32.
tx, ty, c = (torch.tensor(z).float() for z in bs.tck)
c = c.reshape(h, w)

In [106]:
# Evaluate the fitted tensor-product B-spline (knots tx / ty, coefficient grid c)
# on an H x H coordinate grid with a de Boor-style recursion, one sample at a time.
# Result: `w`, a flat tensor with one spline value per row of `xy`.
# NOTE: `w` rebinds the name used for the kernel width in the spline-fit cell.
H = 50
xy = util.meshgrid_coords(H,H).cpu()
w_oi = []
X = xy[:,0].unsqueeze(1)
Y = xy[:,1].unsqueeze(1)
px = py = 2  # spline degree along x and y

# Knot-span lookup. (t <= coord) is True for every knot at or below the sample;
# min over the knot axis returns (all-True flag, index of a False entry).
# Each axis keeps its OWN flag: previously a single `values` variable was
# overwritten by the y computation and then used to patch kx as well, and it
# also clobbered the unrelated `values` tensor defined in an earlier cell.
x_past_end, kx = (tx<=X).min(dim=-1)
y_past_end, ky = (ty<=Y).min(dim=-1)
kx -= 1
ky -= 1
# Samples at or beyond the last knot have no "first greater knot"; pin them to
# the last span index used by the recursion below.
kx[x_past_end] = tx.size(-1)-px-2
ky[y_past_end] = ty.size(-1)-py-2

Ctrl = c.view(1, *c.shape[-2:])
for z in range(X.size(0)):
    # Local (px+1) x (py+1) window of control points that influence sample z.
    D = Ctrl[:, kx[z]-px : kx[z]+1, ky[z]-py : ky[z]+1].clone()

    # Recursion along x: after px rounds, row px of D holds the x-evaluated values.
    for r in range(1, px + 1):
        try:
            alphax = (X[z,0] - tx[kx[z]-px+1:kx[z]+1]) / (
                tx[2+kx[z]-r:2+kx[z]-r+px] - tx[kx[z]-px+1:kx[z]+1])
        except RuntimeError as err:
            # Fail loudly: printing and continuing would reuse a stale alphax from
            # a previous iteration (or hit a NameError on the very first one).
            raise RuntimeError(
                f"input off the grid at sample {z} (x={X[z,0].item()})") from err
        for j in range(px, r - 1, -1):
            D[:,j] = (1-alphax[j-1]) * D[:,j-1] + alphax[j-1] * D[:,j].clone()

    # Recursion along y on the surviving row; the value ends up in D[:, px, py].
    for r in range(1, py + 1):
        alphay = (Y[z,0] - ty[ky[z]-py+1:ky[z]+1]) / (
            ty[2+ky[z]-r:2+ky[z]-r+py] - ty[ky[z]-py+1:ky[z]+1])
        for j in range(py, r-1, -1):
            D[:,px,j] = (1-alphay[j-1]) * D[:,px,j-1].clone() + alphay[j-1] * D[:,px,j].clone()

    w_oi.append(D[:,px,py])

w = torch.stack(w_oi).view(xy.size(0))

In [107]:
# Arrange the flat per-sample spline values into an H x H image for display.
k = w.reshape(H,H)

In [None]:
# Display the resampled kernel `k` with the same fixed colour limits as the
# original-kernel plot so the two images are directly comparable.
# NOTE: the figure size is whatever plt.rcParams["figure.figsize"] was last set to
# by an earlier cell.
fig,ax = plt.subplots()
print(k.min().item(), k.max().item())
ax.imshow(k, vmin=-2.9, vmax=2.9);
plt.axis('off');

## slurm

In [30]:
# Per-dataset lengths used as the upper bound of the chunked job-listing loops
# in the cells below (presumably the number of items in each dataset — confirm).
coco_len = 118287
kitti_len = 80896
horse_len = 1067
zebra_len = 1334
inet_len = 50000
in_tr_len = 50000
# 0 means the places loop below iterates zero times and prints nothing.
places_len = 0

In [None]:
# Walk the places range in chunks of 800 and print a fit_inr.sh command line for
# every chunk whose siren_{ix+63}.pt file does not exist yet.
for ix in range(0, places_len, 800):
    if osp.exists(f"/data/vision/polina/scratch/clintonw/datasets/inrnet/places/siren_{ix+63}.pt"):
        continue  # this chunk's marker file is already present
    print("sh fit_inr.sh", f"place{ix//800}", "fit_places", ix)

In [None]:
# Walk the inet1k_train range in chunks of 800 and print a fit_inr.sh command
# line for every chunk whose siren_{ix+63}.pt file does not exist yet.
# NOTE(review): the printed job name uses the prefix "inet", the same prefix the
# imagenet1k (fit_inet) loop prints — confirm that identical job names are intended.
for ix in range(0, in_tr_len, 800):
    if osp.exists(f"/data/vision/polina/scratch/clintonw/datasets/inrnet/inet1k_train/siren_{ix+63}.pt"):
        continue  # this chunk's marker file is already present
    print("sh fit_inr.sh", f"inet{ix//800}", "fit_in_tr", ix)

In [None]:
# Walk the imagenet1k range in chunks of 800 and print a fit_inr.sh command line
# for every chunk whose siren_{ix+63}.pt file does not exist yet.
for ix in range(0, inet_len, 800):
    if osp.exists(f"/data/vision/polina/scratch/clintonw/datasets/inrnet/imagenet1k/siren_{ix+63}.pt"):
        continue  # this chunk's marker file is already present
    print("sh fit_inr.sh", f"inet{ix//800}", "fit_inet", ix)

In [None]:
# Walk the coco range in chunks of 1600 and print a fit_inr.sh command line for
# every chunk whose siren_{ix+63}.pt file does not exist yet.
for ix in range(0, coco_len, 1600):
    if osp.exists(f"/data/vision/polina/scratch/clintonw/datasets/inrnet/coco/siren_{ix+63}.pt"):
        continue  # this chunk's marker file is already present
    print("sh fit_inr.sh", f"coco{ix//1600}", "fit_coco", ix)

In [None]:
# Walk the kitti range in chunks of 800 and print a fit_inr.sh command line for
# every chunk whose siren_{ix+63}.pt file does not exist yet.
for ix in range(0, kitti_len, 800):
    if osp.exists(f"/data/vision/polina/scratch/clintonw/datasets/inrnet/kitti/siren_{ix+63}.pt"):
        continue  # this chunk's marker file is already present
    print("sh fit_inr.sh", f"skit{ix//800}", "fit_kitti", ix)

In [None]:
# Walk the zebra range in chunks of 128 and print a fit_inr.sh command line for
# every chunk whose siren_{ix+15}.pt file does not exist yet.
for ix in range(0, zebra_len, 128):
    if osp.exists(f"/data/vision/polina/scratch/clintonw/datasets/inrnet/zebra/siren_{ix+15}.pt"):
        continue  # this chunk's marker file is already present
    print("sh fit_inr.sh", f"zebra{ix//128}", "fit_zebra", ix)

In [None]:
# Walk the horse range in chunks of 128 and print a fit_inr.sh command line for
# every chunk whose siren_{ix+15}.pt file does not exist yet.
for ix in range(0, horse_len, 128):
    if osp.exists(f"/data/vision/polina/scratch/clintonw/datasets/inrnet/horse/siren_{ix+15}.pt"):
        continue  # this chunk's marker file is already present
    print("sh fit_inr.sh", f"horse{ix//128}", "fit_horse", ix)

In [None]:
# Read one horse2zebra testA JPEG into an array for inspection.
img = plt.imread("/data/vision/polina/scratch/clintonw/datasets/inrnet/horse2zebra/testA/n02381460_9260.jpg")

In [None]:
# Bare string kept as a note of where the COCO panoptic train2017 annotations live;
# executing the cell only echoes the path.
"/data/vision/polina/scratch/clintonw/datasets/coco/annotations/panoptic_train2017"

In [None]:
# Check whether a kitti siren file exists in the temp directory.
# NOTE: `ix` is not defined in this cell — it is whatever value the most recently
# executed job-listing loop left behind, so the result depends on execution order.
osp.exists(f"/data/vision/polina/users/clintonw/code/diffcoord/temp/kitti/siren_{ix+63}.pt")

In [60]:
# Sorted list of all kitti siren_*.pt files found by the project's glob helper.
# NOTE: plain sorted() orders strings lexicographically, so siren_10 sorts before siren_2.
paths = sorted(util.glob2("/data/vision/polina/scratch/clintonw/datasets/inrnet/kitti/siren_*.pt"))

In [None]:
## search python strings
import os
osp = os.path


def find_python_files_containing(root, query_string, suffix=".py"):
    """Return paths of files under ``root`` whose text contains ``query_string``.

    Args:
        root: directory to walk recursively; a missing directory yields no results.
        query_string: literal substring to look for (not a regular expression).
        suffix: only files whose name ends with this suffix are opened.

    Returns:
        List of matching file paths, in ``os.walk`` traversal order.
    """
    matches = []
    for folder, _subfolders, files in os.walk(root):
        for file in files:
            if not file.endswith(suffix):
                continue
            path = osp.join(folder, file)
            # Decode leniently with a fixed encoding: without this, one file that
            # is not valid in the platform default encoding aborts the whole search.
            with open(path, 'r', encoding="utf-8", errors="ignore") as f:
                if query_string in f.read():
                    matches.append(path)
    return matches


root = osp.expanduser("~/code/placenta/placenta")
query_string = 'load_img'
for path in find_python_files_containing(root, query_string):
    print(path)