In [1]:
from models import *
from qhoptim.pyt import QHAdam
import pytorch_lightning as pl

In [2]:
import build
loader = build.dataloader(data='dataset_v6', sample_length=4096, part=0.001, batch_size=1, shuffle=True)
print(len(loader))

1


In [3]:
x, y = next(iter(loader))
print(x.shape, y.shape)

torch.Size([1, 256, 4096]) torch.Size([1])


In [4]:
print(x.min(), x.max())

tensor(0.) tensor(1.)


In [5]:
from models import *
from qhoptim.pyt import QHAdam
import pytorch_lightning as pl

import build
loader = build.dataloader(data='dataset_v6', sample_length=2**13, part=0.001, batch_size=1, shuffle=True)


In [6]:
premodel = MixtureNet(layers=3, blocks=3, res_channels=32, end_channels=32, classes=256, groups=16)

In [7]:
def NLL(y_hat, y):
    assert y_hat.shape == y.shape
    assert y_hat.dim() == 3 # (batch, channels, length)
    return -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat)).sum(dim=1).mean()

class DaNet(pl.LightningModule):
    def __init__(self, model_loader, noise=0.0, lr=0.0):
        super().__init__()
        self.save_hyperparameters()
        self.model = eval(model_loader)
        self.lr = lr
        self.output_size = 2**12
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, aux = batch
        inp = F.pad(x, (1, -1)) + torch.randn_like(x) * self.hparams.noise
        x_hat = self.forward(inp)
        loss = NLL(x_hat[-self.output_size:], x[-self.output_size:])
        self.log("train_loss", loss)
        return loss
    
    def configure_optimizers(self):
        optimizer = QHAdam(self.parameters(), lr=self.lr, nus=(0.7, 1.0))
        return {"optimizer": optimizer, "lr_scheduler": torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5**0.001)}

In [11]:
model = DaNet(module_description(premodel), noise=0.0, lr=1e-3)

In [12]:
x = torch.randn((10, 256, 2 ** 13))
y = model(x)

tensor(-21032.4688, grad_fn=<SumBackward0>)


In [9]:
trainer = pl.Trainer(log_every_n_steps=1)
trainer.fit(model, train_dataloaders=loader)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | MixtureNet | 11.8 K
-------------------------------------
11.8 K    Trainable params
0         Non-trainable params
11.8 K    Total params
0.047     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] tensor(-2214.6467, grad_fn=<SumBackward0>)
Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\utils\python_arg_parser.cpp:1055.)
  exp_avg.mul_(beta1_adj).add_(1.0 - beta1_adj, d_p)


tensor(nan, grad_fn=<SumBackward0>)
Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 4:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 5:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 6:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 7:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 8:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 9:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]        tensor(nan, grad_fn=<SumBackward0>)
Epoch 10:   0%|          | 0/1 [00:00<?, ?it/s, loss=nan, v_num=18]       tensor(nan

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
for p in model.parameters():
    print(p)