# Hypernetworks experiments

Faisal Qureshi      
faisal.qureshi@ontariotechu.ca

In [None]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Load the source image and its enhanced counterpart as numpy arrays
# (presumably (h, w, 3) uint8, given the "rgb8bit" directory).
image_file1 = '../../data/imagecompression.info/rgb8bit/deer-small.ppm'
image_file2 = '../../data/imagecompression.info/rgb8bit/deer-small-enhanced.ppm'
image_in, image_out = (np.array(Image.open(f)) for f in (image_file1, image_file2))

# Show the two images side by side.
plt.figure(figsize=(10, 5))
for panel, (caption, img) in enumerate(
        (('Input image', image_in), ('Enhanced image', image_out)), start=1):
    plt.subplot(1, 2, panel)
    plt.title(caption)
    plt.imshow(img)

In [None]:
import einops
def lift(im):
    """Lift an (h, w, c) colour image to per-pixel quadratic features.

    For an RGB image each pixel (r, g, b) becomes the 10-vector
    (r2, g2, b2, rg, rb, gb, r, g, b, 1).

    Args:
        im: array of shape (h, w, c) with c >= 3; any numeric dtype
            (8-bit images straight from PIL are fine).

    Returns:
        float64 array of shape (2*c + 4, h*w), i.e. (10, h*w) for RGB.
        Column ``i`` holds the features of pixel ``i`` in row-major (h, w) order.
    """
    h, w, c = im.shape
    # (h, w, c) -> (c, h*w). Cast to float *before* squaring: with uint8 input,
    # x**2 and the channel products would silently wrap around modulo 256.
    x = np.asarray(im, dtype=np.float64).transpose(2, 0, 1).reshape(c, h * w)
    r, g, b = x[0], x[1], x[2]
    return np.vstack((x ** 2, r * g, r * b, g * b, x, np.ones(h * w)))

In [None]:
# Lift the input image into per-pixel features and report the resulting shape.
im_lifted = lift(image_in)
print(np.shape(im_lifted))

In [None]:
# Fold the (features, h*w) lifted array back into an image-shaped
# (h, w, features) array. Take the width from the actual image rather than
# hard-coding 512, so this works for any input size.
einops.rearrange(im_lifted, 'c (h w) -> h w c', w=image_in.shape[1]).shape

In [None]:
import torch.utils.data as tdata

class PixData(tdata.Dataset):
    """Per-pixel dataset pairing lifted input colours with target colours.

    Item ``idx`` is one pixel (row-major order): ``'data'`` is its lifted
    feature vector (see ``lift``) computed from colours scaled to [0, 1], and
    ``'out'`` is the matching pixel of the output image, also scaled to [0, 1].
    """

    def __init__(self, image_in, image_out):
        """Build the dataset from two 8-bit RGB images of the same size.

        Args:
            image_in: source image, array of shape (h, w, 3).
            image_out: target image, array of shape (h, w, 3).

        Raises:
            ValueError: if the two images do not have the same height/width.
        """
        if image_in.shape[:2] != image_out.shape[:2]:
            raise ValueError(
                f"image_in {image_in.shape} and image_out {image_out.shape} "
                "must have the same height and width")
        self.h, self.w, _ = image_in.shape
        # Normalise *before* lifting so the quadratic/cross terms stay in
        # [0, 1] and the constant feature stays exactly 1 (dividing after
        # lifting would leave r**2 in [0, 255] and the constant at 1/255).
        self.image_in = torch.Tensor(lift(image_in / 255.))
        # (h, w, c) -> (c, h*w): one column per pixel, same layout as lift().
        self.image_out = torch.Tensor(
            einops.rearrange(image_out / 255., 'h w c -> c (h w)'))

    def __len__(self):
        """Number of pixels (one sample per pixel)."""
        return self.h * self.w

    def __getitem__(self, idx):
        """Return the features and target colour of pixel ``idx``."""
        return {
            'data': self.image_in[:, idx],
            'out': self.image_out[:, idx]
        }

In [None]:
# Wrap the image pair as a per-pixel dataset (one sample per pixel).
dataset = PixData(image_in=image_in, image_out=image_out)
print(len(dataset))

# Shape of the first pixel's feature vector (shown as the cell output).
dataset[0]['data'].shape

In [None]:
class Correct(torch.nn.Module):
    """Probe what happens when a Linear layer's weight is replaced in forward.

    Holds a bias-free 10 -> 3 Linear layer. Each forward pass prints diagnostics,
    overwrites the layer weight with a fresh all-zero Parameter, and applies
    the layer, so the output is always zero.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 3, bias=False)
        print(self.linear.weight.shape)

    def forward(self, x):
        layer = self.linear
        print(x.shape)
        print(layer.weight)
        # NOTE(review): a new Parameter object is created on every call, so an
        # optimizer built from the old parameters would not update this one.
        layer.weight = torch.nn.Parameter(torch.zeros((3, 10)))
        print(layer.weight)

        result = layer(x)
        print(result.shape)
        return result

In [None]:
c = Correct()
# Add a leading batch axis: (features,) -> (1, features). `rearrange` is only
# reachable through the imported `einops` module, not as a bare name.
c(einops.rearrange(dataset[0]['data'], 'w -> () w'))

In [None]:
import torchvision.models as torch_models

In [None]:
# Build a ResNet-18 with pretrained weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=torch_models.ResNet18_Weights.IMAGENET1K_V1`; switch once
# the minimum supported torchvision version allows it.
resnet18 = torch_models.resnet18(pretrained=True)

In [None]:
# Print the module hierarchy (submodule names and their configurations).
print(resnet18)

In [None]:
# Freeze the whole backbone: no parameter of resnet18 will receive gradients.
# The trailing semicolon keeps the notebook from echoing the returned module.
resnet18.requires_grad_(False);

In [None]:
# Map each dotted parameter name to its tensor and display the pairs.
x = {name: param for name, param in resnet18.named_parameters()}
x.items()

In [None]:
# Reload the input image resized to 224x224 with bilinear resampling and
# inspect its array shape and value range.
image_file1 = '../../data/imagecompression.info/rgb8bit/deer-small.ppm'
tmp = np.array(
    Image.open(image_file1).resize(size=(224, 224), resample=Image.BILINEAR)
)
print(tmp.shape)
print(tmp.min(), tmp.max())

In [None]:
from torchvision import transforms

In [None]:
# Preprocessing pipeline: resize to 224x224, convert to a float tensor, then
# normalise each channel (presumably with the ImageNet statistics expected by
# the pretrained backbone).
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Run the input image through the preprocessing pipeline and inspect the
# resulting tensor's shape and value range.
tmp = Image.open(image_file1)
tmp = data_transforms(tmp)
print(tmp.shape)
print(tmp.min(), tmp.max())

In [None]:
# ToTensor produces a channels-first (c, h, w) tensor, so the channel axis is
# the first dimension. Unpacking it as (h, w, c) would give h == 3.
_, h, w = tmp.shape
print(h, w)

In [None]:
class Test(nn.Module):
    """Toy module with two Linear layers, used to explore parameter handling.

    ``l1`` (3 -> 5) and ``l2`` (5 -> 1) are registered as submodules, but
    ``forward`` does not use them: it returns its input unchanged.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(3, 5)
        self.l2 = nn.Linear(5, 1)

    def forward(self, x):
        # Identity; the layers exist only so they appear in parameters().
        return x

In [None]:
# Instantiate the toy module (l1: 3 -> 5, l2: 5 -> 1).
t = Test() 

In [None]:
# Print every parameter tensor of `t`, one per line.
print(*t.parameters(), sep='\n')

In [None]:
# Same listing, prefixed with each parameter's dotted attribute path.
for name, param in t.named_parameters():
    print(name, param)

In [None]:
# Freeze only the first layer's weight (its bias stays trainable), then show
# the name -> parameter map so the requires_grad flags can be compared.
t.l1.weight.requires_grad_(False)
x = dict(t.named_parameters())
print(x)

In [None]:
class Test2(nn.Module):
    """Wrap a frozen ``Test`` and swap in a new, trainable ``b.l2`` head.

    After construction, ``b.l1`` stays frozen, while the replacement
    ``b.l2`` (5 -> 2) and the extra ``l3`` (4 -> 4) are trainable.
    ``forward`` returns its input unchanged.
    """

    def __init__(self):
        super().__init__()

        self.b = Test()
        # Freeze every parameter that came with the wrapped module ...
        self.b.requires_grad_(False)

        print(self.b.l2)
        # ... then replace its head; a newly built Linear is trainable.
        self.b.l2 = nn.Linear(5, 2)

        self.l3 = nn.Linear(4, 4)

    def forward(self, x):
        return x

In [None]:
# Build the wrapper; its constructor prints the original (frozen) b.l2 before replacing it.
t2 = Test2()

In [None]:
# Full listing of t2's parameters, frozen ones included.
for name, param in t2.named_parameters():
    print(name, param)

In [None]:
# Only the parameters that will receive gradients (i.e. are not frozen).
print([p for p in t2.parameters() if p.requires_grad])