In [None]:
import sys
sys.path.append("..")  # make the project-local `src` package importable from this notebook's directory

from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

from src.datasets import *
# NOTE(review): this line used to read `from src.util.image import ` with nothing after
# `import` — a SyntaxError that prevented this entire cell (including the torch imports)
# from running. No cell shown here uses a name from src.util.image; import the specific
# names explicitly if a later cell needs them.
# from src.util.image import ...
#from src.models.cnn import *

# Seed torch's global RNG so the random tensors, weight initialisation and Bernoulli
# samples in the cells below are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Anomaly detection gives better autograd error traces but slows every backward pass;
# set to False for long training runs.
DETECT_ANOMALY = True
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

In [None]:
# Two independent vectors of 10 uniform [0, 1) samples; `t` is drawn before `n`.
t, n = torch.rand(10), torch.rand(10)

In [None]:
print(t)
print(n)
# With n ~ U[0, 1), P(t >= n) = t elementwise, so this threshold is a hand-rolled
# Bernoulli(t) sample. It sat mid-cell, so its value was never shown; print it so it
# can be compared with the torch.bernoulli samples below.
print((t >= n).to(t))
for _ in range(10):
    print(t.bernoulli())
# Earlier deterministic experiment (binarises t at 0.5), kept for reference:
#-F.threshold(-F.threshold(t, .5, 0.), -0.0001, -1.)

In [None]:
class RBM(nn.Module):
    """Restricted Boltzmann machine with binary (Bernoulli-sampled) units and tied weights.

    Visible units (``num_in``) and hidden units (``num_out``) share one weight matrix,
    used as ``weight`` for visible -> hidden and ``weight.t()`` for hidden -> visible.

    Args:
        num_in: number of visible units; inputs are flattened to ``(-1, num_in)``.
        num_out: number of hidden units.
        act_fn: maps pre-activations to unit probabilities. ``bernoulli()`` is applied
            to its output when sampling, so it must return values in [0, 1].
            ``free_energy`` uses the softplus form, which corresponds to the default
            ``torch.sigmoid``.
        bias: if False, no visible/hidden biases are created (both attributes are None).
    """

    def __init__(
            self,
            num_in: int,
            num_out: int,
            act_fn: Optional[Callable] = torch.sigmoid,
            bias: bool = True,
    ):
        super().__init__()
        self.num_in = num_in
        self.num_out = num_out
        self.act_fn = act_fn

        # Parameters are created in the order visible bias, hidden bias, weight so that
        # seeded runs with bias=True draw the same initial values as before.
        # NOTE(review): unit-variance randn weights likely saturate the sigmoid for large
        # num_in (e.g. 3*32*32); a small std such as 0.01 is a common RBM choice.
        if bias:
            self.bias_visible = nn.Parameter(torch.randn(1, self.num_in))
            self.bias_hidden = nn.Parameter(torch.randn(1, self.num_out))
        else:
            # Registered as None so the attributes exist and F.linear skips the bias term.
            self.register_parameter("bias_visible", None)
            self.register_parameter("bias_hidden", None)
        self.weight = nn.Parameter(torch.randn(self.num_out, self.num_in))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return hidden activations ``act_fn(x @ weight.T + bias_hidden)``, shape (batch, num_out)."""
        x = x.reshape(-1, self.num_in)
        y = F.linear(x, self.weight, self.bias_hidden)
        if self.act_fn is not None:
            y = self.act_fn(y)
        return y

    def visible_to_hidden(self, x: torch.Tensor) -> torch.Tensor:
        """Sample binary hidden states from the activations computed by ``forward``."""
        return self.forward(x).bernoulli()

    def hidden_to_visible(self, x: torch.Tensor) -> torch.Tensor:
        """Sample binary visible states from hidden states, using the transposed weights."""
        x = x.reshape(-1, self.num_out)
        y = F.linear(x, self.weight.t(), self.bias_visible)
        if self.act_fn is not None:
            y = self.act_fn(y)
        return y.bernoulli()

    def contrastive_divergence(
            self,
            x: torch.Tensor,
            num_steps: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``num_steps`` rounds of alternating Gibbs sampling starting from ``x``.

        Returns:
            ``(x_flat, x_out)``: the input reshaped to ``(-1, num_in)`` and the last sampled
            visible state. If ``num_steps <= 0``, ``x_out`` is ``x_flat`` itself.
        """
        x = x_out = x.reshape(-1, self.num_in)

        state = self.visible_to_hidden(x)
        for step in range(num_steps):
            x_out = self.hidden_to_visible(state)
            # Skip the hidden resample after the final visible sample; it would be unused.
            if step < num_steps - 1:
                state = self.visible_to_hidden(x_out)
        return x, x_out

    def weight_images(self) -> List[torch.Tensor]:
        """Return the weight matrix (num_out, num_in) wrapped in a list, for visualisation."""
        return [self.weight]

    def free_energy(self, x: torch.Tensor) -> torch.Tensor:
        """Mean free energy over the batch.

        ``F(v) = -v . bias_visible - sum_j softplus((v @ weight.T + bias_hidden)_j)``

        Both terms are kept at shape (batch,). Previously a (batch, 1) term was broadcast
        against a (batch,) term into a (batch, batch) matrix before averaging.
        """
        x = x.reshape(-1, self.num_in)
        hidden_pre = F.linear(x, self.weight, self.bias_hidden)
        h_term = F.softplus(hidden_pre).sum(dim=1)
        if self.bias_visible is None:
            v_term = torch.zeros_like(h_term)
        else:
            v_term = torch.matmul(x, self.bias_visible.t()).squeeze(1)
        return torch.mean(-h_term - v_term)

    def train_step(self, input_batch) -> torch.Tensor:
        """Return the CD-1 loss for one batch; the caller runs backward() and the optimizer.

        ``input_batch`` may be a tensor, or a tuple/list whose first item is the tensor
        (as yielded by a DataLoader over a TensorDataset).
        """
        if isinstance(input_batch, (tuple, list)):
            input_batch = input_batch[0]

        first_state, last_state = self.contrastive_divergence(input_batch)
        # The negative sample comes from bernoulli() draws, which pass back zero gradient.
        # Detaching leaves the gradients unchanged but stops backward() from traversing
        # the whole sampling chain.
        last_state = last_state.detach()
        loss = (self.free_energy(first_state) - self.free_energy(last_state)) / self.num_out
        return loss
    
    
# Smoke test on a tiny model. Only a cell's last expression is displayed, so the
# intermediate results are printed explicitly; otherwise they would be discarded.
model = RBM(10, 5)
x = torch.rand(1, 10)
print("visible:", x)
print("hidden activations:", model.forward(x))
print("CD-1 (input, reconstruction):", model.contrastive_divergence(x))
model.free_energy(x)

In [None]:
# Sanity check on a small RBM: for 20 random visible vectors, show the input, the hidden
# activations and the one-step contrastive-divergence reconstruction (no autograd needed).
model = RBM(5, 3)
with torch.no_grad():
    for _ in range(20):
        visible = torch.rand(1, 5)
        hidden_act = model.forward(visible)
        reconstruction = model.contrastive_divergence(visible)[1]
        print(visible, hidden_act, reconstruction)


In [None]:
DATASET_FILE = Path("../datasets/diverse-32x32-aug16.pt")
# torch.load unpickles the file; only point this at trusted data.
dataset = TensorDataset(torch.load(DATASET_FILE))

# Each TensorDataset item is a 1-tuple, hence the trailing [0].
first_image = dataset[0][0]
# Show the image next to one Bernoulli sample of it, treating pixels as on-probabilities.
# bernoulli() raises on values outside [0, 1], so clamp in case the stored tensor is not
# already normalised.
VF.to_pil_image(make_grid([
    first_image, first_image.clamp(0, 1).bernoulli()
]))

In [None]:
model = RBM(3*32*32, 100)
optimizer = torch.optim.Adadelta(model.parameters(), .1)


def train_step(batch: torch.Tensor) -> torch.Tensor:
    """Run one contrastive-divergence update of the global ``model`` on ``batch``.

    Reuses ``RBM.train_step`` for the loss, then applies the gradient with ``optimizer``.
    Returns the loss computed before the update.
    """
    loss = model.train_step(batch)
    optimizer.zero_grad()
    loss.backward()
    # Apply the gradients; without this step the parameters never change.
    optimizer.step()
    return loss


for batch, in DataLoader(dataset, batch_size=10):
    batch = batch.reshape(-1, 3*32*32)
    loss = train_step(batch)
    print(round(loss.item(), 2), end=" ")

In [None]:
# Load the raw tensor stored in the dataset file (the same file wrapped by `dataset`).
t = torch.load(Path("../datasets/diverse-32x32-aug16.pt"))
# One-off fix-up, kept commented out: it OVERWRITES the dataset file with a copy clamped to [0, 1].
#torch.save(t.clamp(0, 1), "../datasets/diverse-32x32-aug16.pt")