Part 1: Fit a Neural Field to a 2D Image

In [None]:
from PIL import Image
import mediapy as media
from pprint import pprint
from tqdm import tqdm

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel

# For downloading web images
import requests
from io import BytesIO
import numpy as np
import matplotlib.pyplot as plt
import os
import torch.nn as nn
import tqdm
 

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

class MLP(nn.Module):
    """Coordinate network: sinusoidal positional encoding followed by an MLP.

    Each input coordinate ``x`` is expanded to
    ``[x, sin(2^0*pi*x), cos(2^0*pi*x), ..., sin(2^(L-1)*pi*x), cos(2^(L-1)*pi*x)]``
    and the concatenated encoding is mapped to 3 values squashed into
    ``[0, 1]`` by a final sigmoid.

    Args:
        input_size: Number of coordinates per sample (last dimension of the
            input tensor).
        L: Number of frequency bands in the positional encoding.
        device: Device the layers are moved to and inputs are sent to.
    """

    def __init__(self, input_size=2, L=10, device='cuda'):
        super().__init__()

        self.input_size = input_size
        self.L = L
        self.device = device

        # Every coordinate contributes itself plus one sin/cos pair per band.
        encoded_size = input_size * (2 * L + 1)

        self.mlp = nn.Sequential(
            nn.Linear(encoded_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 3),
            nn.Sigmoid(),
        ).to(device)

    def forward(self, x):
        """Encode coordinates ``x`` of shape ``(..., input_size)`` and return
        the network output of shape ``(..., 3)``."""
        # Cast up front so the raw coordinates and their sin/cos features share
        # the float32 dtype of the layer weights.
        x = x.to(device=self.device, dtype=torch.float32)

        pe = [x]
        for i in range(self.L):
            freq = (2 ** i) * np.pi * x  # frequency: 2^i * pi * x
            pe.append(torch.sin(freq))
            pe.append(torch.cos(freq))

        return self.mlp(torch.cat(pe, dim=-1))

def compute_psnr(pred, target):
    """Return the peak signal-to-noise ratio between two tensors.

    Computes ``10 * log10(1 / MSE)``, i.e. a peak signal value of 1.0 is used.

    Args:
        pred: Predicted values (floating-point tensor).
        target: Reference values, broadcastable against ``pred``.

    Returns:
        A 0-d tensor holding the PSNR. When the inputs are identical the
        value is capped at 100 instead of being infinite.
    """
    mse = torch.mean((pred - target) ** 2)
    if mse == 0:
        # Return a tensor (same dtype/device as the MSE) rather than a bare
        # int so callers always get the same type back.
        return torch.full_like(mse, 100.0)
    return 10 * torch.log10(1.0 / mse)

def rand_sample(im, N):
    """Draw ``N`` random pixels (with replacement) from an image array.

    Args:
        im: NumPy array indexed as ``im[row, col]``; any trailing dimensions
            (e.g. colour channels) are carried through unchanged.
        N: Number of pixels to sample.

    Returns:
        A ``(points, colors)`` pair of NumPy arrays. ``points`` has shape
        ``(N, 2)`` and holds ``[col / width, row / height]`` for each sample;
        ``colors`` holds the sampled pixel values divided by 255 (this scales
        to ``[0, 1]`` assuming 8-bit data).
    """
    height, width = im.shape[:2]
    rows = np.random.randint(0, height, N)
    cols = np.random.randint(0, width, N)

    # Vectorised equivalent of sampling one pixel at a time: x (column)
    # first, y (row) second.
    points = np.stack([cols / width, rows / height], axis=1)
    colors = im[rows, cols] / 255
    return points, colors

def create_image(ima, model, L, downscale=3):
    """Render the model's prediction on a regular pixel grid.

    Args:
        ima: Array whose first two dimensions give the reference height and
            width; its pixel values are not read.
        model: Callable mapping an ``(M, 2)`` tensor of ``[x, y]`` coordinates
            to an ``(M, 3)`` tensor of colours in ``[0, 1]``.
        L: Unused; kept so existing calls keep working.
        downscale: Integer factor by which the reference height and width are
            divided to get the rendered resolution (default 3).

    Returns:
        ``uint8`` NumPy array of shape ``(height // downscale,
        width // downscale, 3)``.

    Raises:
        ValueError: If ``downscale`` is smaller than 1.
    """
    if downscale < 1:
        raise ValueError("downscale must be >= 1")

    h, w = ima.shape[0] // downscale, ima.shape[1] // downscale

    # Normalised coordinates i / size in [0, 1): x varies along columns,
    # y along rows.
    xs = torch.arange(w, dtype=torch.float32) / w
    ys = torch.arange(h, dtype=torch.float32) / h
    grid_x = xs.unsqueeze(0).expand(h, w)
    grid_y = ys.unsqueeze(1).expand(h, w)
    all_coords = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)

    with torch.no_grad():
        predicted_pixels = model(all_coords)

    predicted_image = predicted_pixels.cpu().reshape(h, w, 3).numpy()

    # Scale to 8-bit; astype truncates the fractional part.
    return np.clip(predicted_image * 255, 0, 255).astype(np.uint8)


def impl(im, model, N, L, iterations, lr=0.01):
    """Fit ``model`` to an image by regressing randomly sampled pixels.

    Each iteration samples ``N`` pixels with ``rand_sample`` and takes a
    single Adam step on the mean-squared error of the whole batch. (Taking
    one optimizer step per pixel instead makes every iteration cost ``N``
    forward/backward passes with batch size 1.)

    Args:
        im: Image to fit; converted with ``np.array``.
        model: Network to train. Must expose a ``device`` attribute, which is
            where the sampled batch is placed.
        N: Number of pixels sampled per iteration (the batch size).
        L: Only forwarded to ``create_image`` for the periodic previews; it
            does not configure the model's positional encoding.
        iterations: Number of optimizer steps.
        lr: Adam learning rate.

    Returns:
        ``(model, loss_list, psnr_list, predicted_images)`` where
        ``loss_list`` and ``psnr_list`` hold one Python float per iteration
        (batch MSE and the PSNR of that batch) and ``predicted_images`` holds
        the preview rendered every 100 iterations.
    """
    im = np.array(im)
    loss_func = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr)

    loss_list = []
    psnr_list = []
    predicted_images = []
    model.train()

    for i in range(iterations):
        points, colors = rand_sample(im, N)
        coords = torch.as_tensor(points, dtype=torch.float32, device=model.device)
        target = torch.as_tensor(colors, dtype=torch.float32, device=model.device)

        optimizer.zero_grad()
        pred = model(coords)
        loss = loss_func(pred, target)
        loss.backward()
        optimizer.step()

        # Store plain floats so the histories can be plotted directly and do
        # not keep device tensors alive.
        loss_value = loss.item()
        psnr_value = float(compute_psnr(pred.detach(), target))
        loss_list.append(loss_value)
        psnr_list.append(psnr_value)
        print(f"Iteration {i}, Loss: {loss_value}, PSNR: {psnr_value}")

        if i % 100 == 0:
            # Render a preview in eval mode, then resume training.
            model.eval()
            image_create = create_image(im, model, L)
            predicted_images.append(image_create)
            plt.imshow(image_create)
            plt.show()
            model.train()

    return model, loss_list, psnr_list, predicted_images


In [None]:
# Run A1: fit fox.jpg with 100 samples per iteration for 20000 iterations,
# then render and display the trained model's output.
fox = Image.open('fox.jpg')
m = MLP(input_size=2, L=10, device='cuda')
print(m)
model_a1, loss_list_a, psnr_list_a, pred_im_a = impl(
    im=fox,
    model=m,
    N=100,
    L=10,
    iterations=20000,
    lr=0.01,
)
im = create_image(ima=np.array(fox), model=model_a1, L=10)
plt.imshow(im)
plt.show()

In [None]:
# Run B: fit fox.jpg with a shallower positional encoding (L=5).
fox = Image.open('fox.jpg')
# impl only forwards its L argument to create_image, so the number of
# frequency bands has to be set on the MLP itself for this run to use L=5.
mb = MLP(2, 5, device='cuda')
model_b, loss_list_b, psnr_list_b, pred_im_b = impl(fox, mb, N=1000, L=5, iterations=1000, lr=0.01)
im = create_image(np.array(fox), model_b, 5)
plt.imshow(im)
plt.show()

In [None]:
# Run C: fit fox.jpg with 10000 samples per iteration, 100 iterations and a
# larger learning rate (0.1), then render and display the result.
fox = Image.open('fox.jpg')
mc = MLP(input_size=2, L=10, device='cuda')
model_c, loss_list_c, psnr_list_c, pred_im_c = impl(
    im=fox,
    model=mc,
    N=10000,
    L=10,
    iterations=100,
    lr=0.1,
)
im = create_image(ima=np.array(fox), model=model_c, L=10)
plt.imshow(im)
plt.show()

In [None]:
# Run CG: fit grass.jpg (the variable is still called `fox`) with 1000
# samples per iteration for 100 iterations, then render and display it.
fox = Image.open('grass.jpg')
mcg = MLP(input_size=2, L=10, device='cuda')
model_cg, loss_list_cg, psnr_list_cg, pred_im_cg = impl(
    im=fox,
    model=mcg,
    N=1000,
    L=10,
    iterations=100,
    lr=0.01,
)
im = create_image(ima=np.array(fox), model=model_cg, L=10)
plt.imshow(im)
plt.show()

In [None]:
# Longer run on fox.jpg: 10000 samples per iteration for 3000 iterations,
# then render and display the result.
fox = Image.open('fox.jpg')
mcgaf = MLP(input_size=2, L=10, device='cuda')
model_cga, loss_list_cga, psnr_list_cga, pred_im_cga = impl(
    im=fox,
    model=mcgaf,
    N=10000,
    L=10,
    iterations=3000,
    lr=0.01,
)
im = create_image(ima=np.array(fox), model=model_cga, L=5)
plt.imshow(im)
plt.show()

In [None]:
# Fit a 5x-downsampled grass.jpg with a deeper positional encoding (L=20).
fox = Image.open('grass.jpg')
fox_np = np.array(fox)
a, b, c = fox_np.shape
print(a, b, c)
# np.resize only truncates/repeats the flattened data and does not resample
# the picture, so downsample with PIL instead. PIL expects (width, height).
fox = np.array(fox.resize((b // 5, a // 5)))
# impl only forwards its L argument to create_image, so the number of
# frequency bands has to be set on the MLP itself for this run to use L=20.
mcga = MLP(2, 20, device='cuda')
model_cga, loss_list_cga, psnr_list_cga, pred_im_cga = impl(fox, mcga, N=10000, L=20, iterations=100, lr=0.01)
# The full-size array is passed here only to set the rendered resolution.
im = create_image(np.array(fox_np), model_cga, 20)
plt.imshow(im)
plt.show()

In [None]:
# Fit a 5x-downsampled grass.jpg: 1000 samples per iteration, 1000 iterations.
fox = Image.open('grass.jpg')
fox_np = np.array(fox)
a, b, c = fox_np.shape
print(a, b, c)
# np.resize only truncates/repeats the flattened data and does not resample
# the picture, so downsample with PIL instead. PIL expects (width, height).
fox_np = np.array(fox.resize((b // 5, a // 5)))
print(fox_np.shape)
mcgb = MLP(2, 10, device='cuda')
model_cgb, loss_list_cgb, psnr_list_cgb, pred_im_cgb = impl(fox_np, mcgb, N=1000, L=10, iterations=1000, lr=0.005)
im = create_image(fox_np, model_cgb, 10)
plt.imshow(im)

In [None]:
# impl takes the model as its second argument and returns four values
# (model, losses, PSNRs, preview images); the previews are discarded here.
fox = Image.open('fox.jpg')
model, loss_list, psnr_list, _ = impl(fox, MLP(2, 10), 10000, 10, 2000, 0.01)




In [None]:
# Run A: fit fox.jpg with 1000 samples per iteration for 100 iterations on
# the MLP's default device, then render the result (not displayed here).
fox = Image.open('fox.jpg')
m = MLP(input_size=2, L=10)
print(m)
model_a, loss_list_a, psnr_list_a, pred_im_a = impl(
    im=fox,
    model=m,
    N=1000,
    L=10,
    iterations=100,
    lr=0.01,
)
im = create_image(ima=np.array(fox), model=model_a, L=10)
