In [4]:
import torch
from torch import nn
import os
import pickle


In [2]:
from vb import CustomDataset

In [3]:
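# One-off dataset construction, left disabled. Re-enable these two lines
# (together with the pickling cell below) to rebuild the datasets from scratch.
# train_dataset = CustomDataset(train=True, target_len=2**13)
# test_dataset = CustomDataset(train=False, target_len=2**13, fixed=True)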


In [6]:

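# One-off caching step, left disabled: serializes the datasets to the
# 'train_dataset.pkl' / 'test_dataset.pkl' files that the loading cell reads.
# with open('train_dataset.pkl', 'wb') as f:
#     pickle.dump(train_dataset, f, pickle.HIGHEST_PROTOCOL)

# with open('test_dataset.pkl', 'wb') as f:
#     pickle.dump(test_dataset, f, pickle.HIGHEST_PROTOCOL)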


100%|██████████| 824/824 [01:18<00:00, 10.55it/s]


In [5]:
# Restore the cached train/test datasets from disk instead of rebuilding them.
# NOTE(review): presumably these are `vb.CustomDataset` instances, in which case
# that class must be importable for unpickling to succeed -- confirm.
# pickle.load can execute arbitrary code embedded in the file, so only load
# pickles that were produced locally by this notebook.
with open('train_dataset.pkl', 'rb') as f:
    train_dataset = pickle.load(f)
with open('test_dataset.pkl', 'rb') as f:
    test_dataset = pickle.load(f)

In [6]:
from torch import nn
import torch


class CBP(nn.Module):
    """Conv1d -> BatchNorm1d -> PReLU block with an optional residual input.

    The convolution uses kernel size 3, stride 1 and padding 1, so the
    length of the last dimension is preserved.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, 3, 1, padding=1)
        self.bn = nn.BatchNorm1d(out_channels)
        self.prelu = nn.PReLU()

    def forward(self, x, residual=None):
        """Apply conv and batch norm, add ``residual`` if given, then PReLU.

        Args:
            x: input tensor with ``in_channels`` channels on dim 1.
            residual: optional tensor added to the normalized features
                before the activation; it must be broadcastable to them.

        Returns:
            The activated tensor with ``out_channels`` channels.
        """
        x = self.bn(self.conv(x))
        # Identity test, not `==`: with a tensor argument `==` dispatches to the
        # tensor's overloaded comparison operator instead of checking whether
        # the argument was supplied.
        if residual is not None:
            x = x + residual
        return self.prelu(x)
        
class Pool(nn.Module):
    """Learned downsampling: strided Conv1d followed by BatchNorm1d and PReLU.

    The convolution has kernel size 2 and stride 2 with no padding, so the
    last dimension is halved while the channel count stays ``in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=2, stride=2, padding=0)
        self.bn = nn.BatchNorm1d(in_channels)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``prelu(bn(conv(x)))``."""
        return self.prelu(self.bn(self.conv(x)))
        
        
class unPool(nn.Module):
    """Learned upsampling: ConvTranspose1d followed by BatchNorm1d and PReLU.

    The transposed convolution has kernel size 2 and stride 2, so the last
    dimension is doubled while the channel count stays ``in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.tconv = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=2, stride=2)
        self.bn = nn.BatchNorm1d(in_channels)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``prelu(bn(tconv(x)))``."""
        return self.prelu(self.bn(self.tconv(x)))
          
        
        
        
        
        

class Encoder(nn.Module):
    """One encoder stage: two residual CBP pairs followed by a Pool.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels of both returned tensors.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # 1x1 convolution that maps the input to out_channels so it can be
        # added inside the first residual pair.
        self.residual = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.cbp11 = CBP(in_channels, out_channels)
        self.cbp12 = CBP(out_channels, out_channels)

        self.cbp21 = CBP(out_channels, out_channels)
        self.cbp22 = CBP(out_channels, out_channels)

        self.pool = Pool(out_channels)

    def forward(self, x):
        """Run both residual pairs, then pool.

        Returns:
            A ``(pooled, skip)`` tuple: ``skip`` is the output of the second
            residual pair and ``pooled`` is ``self.pool(skip)``.
        """
        shortcut = self.residual(x)
        # First pair: the projected input is added before the second activation.
        stage1 = self.cbp12(self.cbp11(x), shortcut)
        # Second pair: the first pair's output serves as the residual.
        skip = self.cbp22(self.cbp21(stage1), stage1)
        return self.pool(skip), skip

    
    
    
    
    
class Decoder(nn.Module):
    """One decoder stage: unPool, concatenate the skip, two residual CBP pairs.

    Args:
        in_channels: channels of the incoming tensor; the skip tensor must
            carry the same number, since the first CBP takes ``in_channels*2``.
        out_channels: channels of the returned tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.unpool = unPool(in_channels)
        # 1x1 convolution that maps the upsampled input to out_channels so it
        # can be added inside the first residual pair.
        self.residual = nn.Conv1d(in_channels, out_channels, kernel_size=1)

        self.cbp11 = CBP(in_channels*2, out_channels)
        self.cbp12 = CBP(out_channels, out_channels)

        self.cbp21 = CBP(out_channels, out_channels)
        self.cbp22 = CBP(out_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, merge it with ``skip`` along dim 1 and refine.

        Args:
            x: tensor coming from the previous (coarser) stage.
            skip: tensor concatenated with the upsampled ``x`` on the
                channel dimension.

        Returns:
            The output of the second residual pair.
        """
        upsampled = self.unpool(x)
        shortcut = self.residual(upsampled)
        merged = self.cbp11(torch.cat((upsampled, skip), dim=1))
        stage1 = self.cbp12(merged, shortcut)
        # Second pair: the first pair's output serves as the residual.
        return self.cbp22(self.cbp21(stage1), stage1)


class Feature_Encoder(nn.Module):
    """Five-level 1-D encoder/decoder with skip connections and a conv head.

    The network takes a single-channel input and produces a single-channel
    output. Each encoder stage halves the last dimension and each decoder
    stage doubles it again, so the input length should be a multiple of
    2**5 for every upsampled tensor to line up with its skip tensor.
    """

    def __init__(self):
        super().__init__()
        # Contracting path: 1 -> 16 -> 32 -> 64 -> 128 -> 256 channels.
        self.encoder_block_1 = Encoder(1, 16)
        self.encoder_block_2 = Encoder(16, 32)
        self.encoder_block_3 = Encoder(32, 64)
        self.encoder_block_4 = Encoder(64, 128)
        self.encoder_block_5 = Encoder(128, 256)

        # Expanding path: 256 -> 128 -> 64 -> 32 -> 16 -> 16 channels.
        self.decoder_block_5 = Decoder(256, 128)
        self.decoder_block_4 = Decoder(128, 64)
        self.decoder_block_3 = Decoder(64, 32)
        self.decoder_block_2 = Decoder(32, 16)
        self.decoder_block_1 = Decoder(16, 16)

        # Head: four length-preserving 3-tap convolutions, then a 1x1
        # projection down to a single channel.
        self.conv1 = nn.Conv1d(16, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv1d(128, 1, kernel_size=1, stride=1)

        self.bn1 = nn.BatchNorm1d(16)
        self.bn2 = nn.BatchNorm1d(32)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)

        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.prelu3 = nn.PReLU()
        self.prelu4 = nn.PReLU()

    def forward(self, x):
        """Encode, decode with skip connections, then apply the conv head."""
        encoders = (
            self.encoder_block_1,
            self.encoder_block_2,
            self.encoder_block_3,
            self.encoder_block_4,
            self.encoder_block_5,
        )
        decoders = (
            self.decoder_block_5,
            self.decoder_block_4,
            self.decoder_block_3,
            self.decoder_block_2,
            self.decoder_block_1,
        )
        head = (
            (self.conv1, self.bn1, self.prelu1),
            (self.conv2, self.bn2, self.prelu2),
            (self.conv3, self.bn3, self.prelu3),
            (self.conv4, self.bn4, self.prelu4),
        )

        out = x
        skips = []
        for encoder in encoders:
            out, skip = encoder(out)
            skips.append(skip)

        # Decoders consume the skips in reverse order (deepest first).
        for decoder in decoders:
            out = decoder(out, skips.pop())

        for conv, bn, activation in head:
            out = activation(bn(conv(out)))

        return self.conv7(out)

# Smoke-test the architecture on a random batch and report its size.
model = Feature_Encoder()
sample_input = torch.randn(64, 1, 8192)
# The output is only checked for its shape, so skip building an autograd
# graph for this large batch.
with torch.no_grad():
    sample_output = model(sample_input)
# The network maps one input channel to one output channel and restores the
# original length, so the output must have the same shape as the input.
assert sample_output.shape == sample_input.shape, sample_output.shape
trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
trainable_parameters

1860247

In [7]:
# Prefer the GPU but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only runs.
pin_memory = device == 'cuda'
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=pin_memory, drop_last=True)
# NOTE(review): drop_last=True also discards the final partial *test* batch, so
# the reported test loss ignores those samples -- confirm this is intended.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, pin_memory=pin_memory, drop_last=True)

In [8]:

from torch.utils.tensorboard import SummaryWriter

# Train the model, logging per-epoch train/test MSE to TensorBoard and saving a
# checkpoint every tenth epoch. The loop is meant to be stopped by hand.

# Prefer the GPU but fall back to the CPU so the cell still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Feature_Encoder().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

writer = SummaryWriter('boards/1_8M_unet')
# NOTE(review): x_back and l_list are not read anywhere in this cell; kept in
# case a later cell relies on them.
x_back = None
l_list = []
trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('trainable parameters:', trainable_parameters)

# Create the checkpoint directory (and any missing parents) up front, so a
# filesystem problem surfaces immediately instead of at the first save.
ckpt_dir = 'ckpts/1_8M_unet'
os.makedirs(ckpt_dir, exist_ok=True)

# Effectively "run until interrupted".
max_epochs = 100_000_000_000
# Checkpoints are written when epoch % ckpt_every == 1 (epochs 1, 11, 21, ...).
ckpt_every = 10

try:
    for epoch in range(0, max_epochs):
        # ---- training pass ----
        t_loss = 0
        model.train()
        # NOTE(review): assumes each batch unpacks as (clean, noisy) -- confirm
        # against the dataset's __getitem__.
        for clean, noisy in train_loader:
            clean, noisy = clean.to(device), noisy.to(device)

            pred = model(noisy)
            loss = loss_fn(pred, clean)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t_loss += loss.item()
        # Mean of the per-batch losses.
        t_loss /= len(train_loader)

        # ---- evaluation pass ----
        model.eval()
        with torch.no_grad():
            t_tloss = 0
            for clean, noisy in test_loader:
                clean, noisy = clean.to(device), noisy.to(device)
                pred = model(noisy)
                loss = loss_fn(pred, clean)
                t_tloss += loss.item()
            t_tloss /= len(test_loader)

        writer.add_scalars('run_14h', {'train_loss': t_loss,
                                       'test_loss': t_tloss}, epoch)
        writer.flush()

        if epoch % ckpt_every == 1:
            print(epoch, t_loss, t_tloss)
            torch.save(model.state_dict(), f'{ckpt_dir}/ckpt_{epoch}_.pt')
finally:
    # Reached on KeyboardInterrupt too, so the event file is always closed.
    writer.close()

trainable parameters: 1860247
1 0.013276073181380828 0.008472477202303708
11 0.0016637162344219783 0.0012572383129736409
21 0.001042173466218325 0.0006198301271069795
31 0.0008076944116813441 0.0004334378803226476
41 0.0006449492231089001 0.00029206398418561247
51 0.000568528243942031 0.00024260403491401425
61 0.000530562955504542 0.00023825467239172818
71 0.00048648909610670267 0.00021541449556631656


KeyboardInterrupt: 