In [5]:
import os
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils.data
from sklearn.datasets import make_swiss_roll
from sklearn.metrics import mean_squared_error
from torch import nn
from tqdm import tqdm

In [6]:
# Compute device for all tensors in this notebook. Prefer Apple-silicon MPS
# (the original target), then CUDA, then CPU, so the notebook still runs on
# machines without an MPS backend instead of failing on the first `.to(DEVICE)`.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [7]:
def sample_batch(size):
    """Draw ``size`` swiss-roll points and return one coordinate as a row vector.

    Only the third coordinate (index 2) of the 3-D points produced by
    ``make_swiss_roll`` is kept. It is scaled by 1/10 and returned as a float
    array of shape ``(1, size)``.
    """
    points, _ = make_swiss_roll(size)
    third_coord_scaled = points[:, 2] / 10.0
    return third_coord_scaled.reshape((1, size))

In [8]:
def reward_function(features, weights, bias=None):
    """Score ``features`` with a logistic model: ``sigmoid(features @ weights + bias)``.

    Args:
        features: left operand of ``torch.matmul``
            (presumably ``(batch, n_features)`` — TODO confirm with callers).
        weights: right operand of ``torch.matmul``
            (presumably ``(n_features, k)``).
        bias: optional value added in place to the logits. Because the add is
            in place, it must broadcast to the logits' shape without enlarging it.

    Returns:
        Element-wise sigmoid of the (optionally biased) logits.
    """
    scores = torch.matmul(features, weights)
    if bias is None:
        return torch.sigmoid(scores)

    # In-place add: `scores` is a fresh tensor produced by the matmul above.
    scores += bias
    return torch.sigmoid(scores)

In [9]:
# Subject ID used to locate the weights file.
model_name = "15"

# Root of the Study1 data. Set the STUDY1_DATA_DIR environment variable to
# point elsewhere instead of editing this absolute, machine-specific path.
STUDY1_DIR = Path(
    os.environ.get(
        "STUDY1_DATA_DIR",
        "/Users/hazimiasad/Documents/Work/megan/data/collection/Study1",
    )
)
path_to_weights = str(STUDY1_DIR / f"sub-{model_name}" / "pattern" / "dc_weights.csv")

# Read the header-less CSV, convert to float32 on the NumPy side (so no
# float64 data is sent to the device), transpose, and move to DEVICE.
weights = torch.from_numpy(
    pd.read_csv(path_to_weights, header=None).to_numpy(dtype=np.float32).T
).to(DEVICE)

# Length of the first dimension after the transpose, i.e. the number of
# columns in the CSV.
state_size = len(weights)

In [10]:
def _first_nan(named_tensors):
    """Return the name of the first tensor containing a NaN, or ``None``.

    ``named_tensors`` is an iterable of ``(name, tensor)`` pairs. ``None``
    tensors (e.g. parameters that received no gradient) are skipped.
    """
    for name, tensor in named_tensors:
        if tensor is not None and torch.isnan(tensor).any():
            return name
    return None


def train_rl(
    model,
    optimizer,
    nb_epochs=150_000,
    batch_size=6_000,
    n_steps=40,
    device=None,
    max_grad_norm=None,
):
    """Train ``model`` with a REINFORCE-style loss ``-reward * sum(log_probs)``.

    Each epoch:
    1. draws a batch with ``sample_batch(batch_size)``;
    2. runs ``model.forward_process(x0, n_steps)``;
    3. steps back from ``n_steps`` to 1 with ``select_action``, collecting the
       returned log-probabilities;
    4. scores the final sample with ``calculate_probability(x, weights)``.

    Training stops early, returning what was collected so far, if a NaN
    appears in the reward, the loss, any gradient (checked *before*
    ``optimizer.step()``), or any parameter.

    Args:
        model: object exposing ``forward_process(x0, t)`` (three return values),
            ``named_parameters()``/``parameters()``, and a ``.model`` attribute
            that is passed to ``select_action``.
        optimizer: optimizer over the parameters being trained.
        nb_epochs: number of training iterations.
        batch_size: number of points drawn per iteration.
        n_steps: number of forward/reverse steps (previously hard-coded to 40).
        device: device for the batch and reward tensors; defaults to ``DEVICE``.
        max_grad_norm: if not ``None``, clip the total gradient norm to this
            value before each optimizer step.

    Returns:
        ``(training_loss, rewards)``: the per-epoch loss values (floats) and the
        raw values returned by ``calculate_probability``.
    """
    if device is None:
        device = DEVICE

    training_loss = []
    rewards = []
    for epoch in tqdm(range(nb_epochs)):
        x0 = torch.from_numpy(sample_batch(batch_size)).float().to(device)
        # Only the third output (the sample) is needed; the first two are unused.
        _, _, x = model.forward_process(x0, n_steps)

        # Reverse pass: t = n_steps, n_steps - 1, ..., 1.
        log_probs = []
        for t in range(n_steps, 0, -1):
            x, log_prob, _, _ = select_action(model.model, x, t)
            log_probs.append(log_prob)

        reward = calculate_probability(x, weights)
        rewards.append(reward)
        # torch.tensor always copies, so the reward carries no autograd history.
        reward = torch.tensor(reward, dtype=torch.float32).to(device)
        if torch.isnan(reward).any():
            print(f"NaN detected in reward at epoch {epoch}")
            break

        loss = -reward * torch.stack(log_probs).sum()
        if torch.isnan(loss).any():
            print(f"NaN detected in loss at epoch {epoch}")
            break

        optimizer.zero_grad()
        loss.backward()

        # Stop before optimizer.step() so NaN gradients never reach the weights.
        bad_grad = _first_nan((name, p.grad) for name, p in model.named_parameters())
        if bad_grad is not None:
            print(f"NaN detected in gradients for {bad_grad} at epoch {epoch}")
            break

        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        training_loss.append(loss.item())

        bad_param = _first_nan(model.named_parameters())
        if bad_param is not None:
            print(f"NaN detected in parameters for {bad_param} at epoch {epoch}")
            break

    return training_loss, rewards

In [11]:
from classes.policy_network import PolicyNetwork

In [12]:
# Build the policy network. Both size arguments are state_size (presumably the
# input and output dimensions — confirm against classes.policy_network).
policy_net = (
    PolicyNetwork(state_size, state_size, device=DEVICE)
    .to(DEVICE)
)

# TODO: wrap the policy and create its optimizer before calling train_rl, e.g.
#   model = DiffusionModel(policy_net)
#   optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)