In [12]:
from myai.imports import *
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms as tv_transforms
from myai.loaders.audio import audioreadtensor

In [9]:
def load_audio(file_path, target_sr=22050)->torch.Tensor:
    """Read an audio file, resample it to ``target_sr`` and down-mix it.

    Args:
        file_path: location of the audio file to decode.
        target_sr: sample rate the returned waveform should have.

    Returns:
        The waveform averaged over its first dimension (kept as a size-1
        axis). Assumes dim 0 is the channel axis, as with ``torchaudio.load``;
        TODO confirm ``audioreadtensor`` uses the same layout.
    """
    try:
        waveform, source_sr = torchaudio.load(file_path)
    except Exception:
        # torchaudio could not decode the file; fall back to the project loader.
        waveform, source_sr = audioreadtensor(file_path)

    if source_sr != target_sr:
        waveform = transforms.Resample(source_sr, target_sr)(waveform)

    # Collapse the first dimension (mono down-mix) but keep it as an axis.
    return waveform.mean(dim=0, keepdim=True)

def compute_stft(audio, n_fft=2048, hop_length=512)->tuple[torch.Tensor,torch.Tensor]:
    """Compute the Hann-windowed STFT of ``audio`` as magnitude and phase.

    Args:
        audio: waveform tensor passed straight to ``torch.stft``.
        n_fft: FFT size; also used as the window length.
        hop_length: number of samples between successive frames.

    Returns:
        ``(magnitude, phase)`` - ``torch.abs`` and ``torch.angle`` of the
        complex STFT, both with the same shape.
    """
    # Create the window where the audio lives: torch.stft needs the window on
    # the same device as its input, and matching the floating dtype avoids
    # mixed-precision surprises. Non-floating inputs keep the default dtype.
    window_dtype = audio.dtype if audio.is_floating_point() else None
    window = torch.hann_window(n_fft, dtype=window_dtype, device=audio.device)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    return torch.abs(stft), torch.angle(stft)

def gram_matrix(feature):
    """Normalised Gram matrix of a 4-D feature map.

    Args:
        feature: tensor whose ``size()`` unpacks to
            ``(batch_size, channels, height, width)``.

    Returns:
        A ``(batch_size * channels, batch_size * channels)`` matrix of inner
        products between flattened feature maps, divided by the total number
        of elements in ``feature``. For ``batch_size == 1`` this is the usual
        ``(channels, channels)`` Gram matrix.
    """
    batch_size, channels, height, width = feature.size()
    # Fold the batch axis into the row axis so batches larger than one no longer
    # break the reshape; ``reshape`` (unlike ``view``) also accepts
    # non-contiguous inputs such as permuted or sliced activations.
    flattened = feature.reshape(batch_size * channels, height * width)
    gram = torch.mm(flattened, flattened.t())
    return gram.div(batch_size * channels * height * width)

class VGGFeatures(nn.Module):
    """Frozen VGG-19 convolutional stack that returns selected activations.

    Args:
        layers: names (indices) of the children of ``vgg19().features`` whose
            outputs should be collected, e.g. ``['1', '6']``. Ints are
            accepted as well as strings.
    """

    def __init__(self, layers):
        super().__init__()
        # ``pretrained=True`` is deprecated in newer torchvision releases; use
        # the weights enum when it exists and fall back to the old flag otherwise.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            backbone = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)
        self.vgg = backbone.features.eval()
        # The network is only used as a fixed feature extractor: freezing the
        # weights stops autograd from tracking gradients for them.
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        self.layers = layers
        # Human-readable names for some layer indices; not used by ``forward``.
        self.layer_map = {
            '3': 'conv1_2', '8': 'conv2_2', '17': 'conv3_4',
            '26': 'conv4_4', '35': 'conv5_4'
        }

    def forward(self, x):
        """Run ``x`` through the network and collect the requested activations.

        Returns:
            A list of activations in *network order* (ascending layer index),
            regardless of the order in which ``layers`` was given.
        """
        # NOTE(review): torchvision's VGG is believed to use in-place ReLUs, so
        # tapping a conv output directly may return post-ReLU values - verify.
        wanted = {str(layer) for layer in self.layers}
        pending = sum(name in wanted for name, _ in self.vgg.named_children())
        features = []
        for name, module in self.vgg.named_children():
            if pending == 0:
                # Every requested layer is collected; deeper layers would only
                # cost time and memory.
                break
            x = module(x)
            if name in wanted:
                features.append(x)
                pending -= 1
        return features

In [None]:
# Input tracks. These are machine-specific absolute paths - point them at your
# own files before running.
CONTENT_PATH = "/var/mnt/ssd/Файлы/Музыка/Tracks/Thook - Leftintide.mp3"
STYLE_PATH = "/var/mnt/ssd/Файлы/Музыка/Tracks/Sorza - Visions of what could be.mp3"

content_audio = load_audio(CONTENT_PATH)
style_audio = load_audio(STYLE_PATH)

# STFTs
content_magnitude, content_phase = compute_stft(content_audio)
style_magnitude, style_phase = compute_stft(style_audio)

# Normalize both spectrograms by their shared peak so they stay on one scale.
max_magnitude:torch.Tensor = torch.maximum(content_magnitude.max(), style_magnitude.max())
# Guard against an all-zero (silent) pair, which would otherwise yield NaNs.
max_magnitude = max_magnitude.clamp_min(torch.finfo(max_magnitude.dtype).tiny)
content_magnitude_normalized = content_magnitude / max_magnitude
style_magnitude_normalized = style_magnitude / max_magnitude

In [11]:
# make into images
content_spectrogram = content_magnitude_normalized.unsqueeze(0).repeat(1, 3, 1, 1)
style_spectrogram = style_magnitude_normalized.unsqueeze(0).repeat(1, 3, 1, 1)

# normalize
# normalize = tv_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
content_spectrogram = znormalize(content_spectrogram)
style_spectrogram = znormalize(style_spectrogram)

# vgg
content_layer = '26'
style_layers = ['1', '6', '11', '20', '29']
vgg = VGGFeatures([content_layer] + style_layers)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /var/home/jj/distrobox/arch/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [07:29<00:00, 1.28MB/s] 


# Okay, I don't have enough memory for this. Big sad.

In [None]:
# features
content_features = vgg(content_spectrogram)
style_features = vgg(style_spectrogram)

# separate content and style features
content_target = content_features[0].detach()
style_targets = [gram_matrix(feat.detach()) for feat in style_features[1:]]