In [1]:
#Load the whole dataset into memory
def get_data():
    """Download ``lansinuote/gen.2.chorales`` and return it as one numpy array.

    Returns:
        float32 array of shape [n, 4, 2, 16, 84], one entry per row of the
        ``train`` split (n is taken from the dataset itself).
    """
    from datasets import load_dataset
    import numpy as np

    #load the train split
    dataset = load_dataset('lansinuote/gen.2.chorales', split='train')

    #copy into one numpy buffer. Sized from the dataset rather than a hard-coded
    #count: np.empty leaves unused rows uninitialised, and extra rows would overflow
    data = np.empty((len(dataset), 4, 2, 16, 84), dtype=np.float32)
    for i in range(len(dataset)):
        data[i] = dataset[i]['data']

    return data


data = get_data()

#sanity check: array shape, then the smallest and largest value
(data.shape, data.min(), data.max())

Using custom data configuration lansinuote--gen.2.chorales-2bf7c47eabbdde89
Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--gen.2.chorales-2bf7c47eabbdde89/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


((229, 4, 2, 16, 84), -1.0, 1.0)

In [2]:
import torch

#shuffled batches of 64 full scores; the last incomplete batch is dropped
loader = torch.utils.data.DataLoader(
    data,
    batch_size=64,
    drop_last=True,
    shuffle=True,
)

#number of batches per pass, and the shape of one batch
(len(loader), next(iter(loader)).shape)

(3, torch.Size([64, 4, 2, 16, 84]))

In [3]:
import music21


#Helper class for listening to results; not essential to the model
class Show():
    """Render a piano-roll tensor as MIDI and display it with music21.

    Call an instance with a torch tensor of shape [4, 2, 16, 84]
    (tracks, bars, steps per bar, pitches), e.g. ``show(pred[0])``.
    """

    def __init__(self, path='./datas/temp.midi'):
        #temporary MIDI file that is written and read back on every call
        self.path = path

    @staticmethod
    def _to_note_grid(data):
        """Convert a [tracks, bars, steps, pitches] tensor to a [bars*steps, tracks] numpy grid.

        Each entry is the argmax pitch index. Column i holds the time steps of track i.
        """
        n_tracks = data.shape[0]
        #[4, 2, 16, 84] -> [4, 2, 16] -> [4, 32]: tracks stay on the first axis
        pitches = data.argmax(dim=-1).reshape(n_tracks, -1)
        #[4, 32] -> [32, 4]: transpose, so each column is one whole track
        #(reshaping [4, 2, 16] straight to [32, 4] would interleave all tracks)
        return pitches.detach().cpu().numpy().T

    #helper function, not essential
    def __merge_note(self, note, duration=None):
        """Merge consecutive repeats of the same pitch into longer notes.

        Args:
            note: 1-D numpy array of pitch indices, one per time step.
            duration: optional 1-D array of per-step durations; defaults to
                0.25 for every step. The input array is not modified.

        Returns:
            ``(note, duration)`` numpy arrays after merging. Neighbours are
            merged only while the combined duration stays <= 1.0.
        """
        import numpy as np

        note = np.asarray(note)
        if duration is None:
            duration = np.full(note.shape, fill_value=0.25, dtype=np.float32)
        else:
            duration = np.asarray(duration)

        if len(note) == 0:
            return note, duration

        #single left-to-right pass: extend the current note while the pitch
        #repeats and the merged duration stays <= 1.0, otherwise start a new one
        merged_notes = [note[0]]
        merged_durations = [duration[0]]
        for n, d in zip(note[1:], duration[1:]):
            if n == merged_notes[-1] and merged_durations[-1] + d <= 1.0:
                merged_durations[-1] = merged_durations[-1] + d
            else:
                merged_notes.append(n)
                merged_durations.append(d)

        return (np.array(merged_notes, dtype=note.dtype),
                np.array(merged_durations, dtype=duration.dtype))

    #helper function, not essential
    def __save_to_mid(self, data):
        """Write a [steps, tracks] grid of pitch indices to ``self.path`` as a MIDI score.

        Each column becomes one music21 Part at tempo 66; repeated pitches are
        merged into longer notes.
        """
        import os

        #data -> [32, 4]
        stream = music21.stream.Score()
        stream.append(music21.tempo.MetronomeMark(number=66))

        for i in range(data.shape[1]):
            channel = music21.stream.Part()

            notes, durations = self.__merge_note(data[:, i])
            for n, d in zip(notes.tolist(), durations.tolist()):
                note = music21.note.Note(n)
                note.duration = music21.duration.Duration(d)
                channel.append(note)

            stream.append(channel)

        #make sure the target directory exists before writing
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        stream.write('midi', fp=self.path)

    def __call__(self, data):
        """Save ``data`` ([4, 2, 16, 84] tensor) as MIDI and play it inline."""
        self.__save_to_mid(self._to_note_grid(data))

        f = music21.midi.MidiFile()
        f.open(self.path)
        try:
            f.read()
        finally:
            #always release the file handle, even if parsing fails
            f.close()
        music21.midi.translate.midiFileToStream(f).show('midi')


show = Show()

#play three random training scores: the first item of three freshly shuffled batches
for _ in range(3):
    show(next(iter(loader))[0])

In [4]:
def get_gen_track():
    """Bar generator: map a [b, 128] latent vector to one [b, 1, 1, 16, 84] piano roll.

    The 16 x 84 grid is built by successive transposed convolutions whose
    kernel equals their stride (pure upsampling, no overlap). There is no output
    activation, so values are unbounded even though the training data lies in [-1, 1].
    """
    layers = [
        torch.nn.Linear(4 * 32, 1024),
        torch.nn.BatchNorm1d(1024),
        torch.nn.ReLU(inplace=True),
        #[b, 1024] -> [b, 512, 2, 1]
        torch.nn.Unflatten(dim=1, unflattened_size=(512, 2, 1)),
    ]

    #(in_channels, out_channels, kernel == stride):
    #steps 2 -> 4 -> 8 -> 16, then pitches 1 -> 7
    upsample_specs = [
        (512, 512, (2, 1)),
        (512, 256, (2, 1)),
        (256, 256, (2, 1)),
        (256, 256, (1, 7)),
    ]
    for c_in, c_out, kernel in upsample_specs:
        layers += [
            torch.nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=kernel, padding=0),
            torch.nn.BatchNorm2d(c_out),
            torch.nn.ReLU(inplace=True),
        ]

    layers += [
        #pitches 7 -> 84, single output channel: [b, 1, 16, 84]
        torch.nn.ConvTranspose2d(256, 1, kernel_size=(1, 12), stride=(1, 12), padding=0),
        #[b, 1, 16, 84] -> [b, 1, 1, 16, 84]
        torch.nn.Unflatten(dim=1, unflattened_size=(1, 1)),
    ]
    return torch.nn.Sequential(*layers)


get_gen_track()(torch.randn(2, 128)).shape

torch.Size([2, 1, 1, 16, 84])

In [5]:
def get_gen_block():
    """Temporal network: expand a [b, 32] vector into one 32-dim vector per bar, [b, 32, 2]."""
    n_bars = 2
    return torch.nn.Sequential(
        #[b, 32] -> [b, 32, 1, 1]
        torch.nn.Unflatten(dim=1, unflattened_size=(32, 1, 1)),
        #[b, 32, 1, 1] -> [b, 1024, 2, 1]
        torch.nn.ConvTranspose2d(32, 1024, kernel_size=(n_bars, 1), stride=(1, 1), padding=0),
        torch.nn.BatchNorm2d(1024),
        torch.nn.ReLU(inplace=True),
        #[b, 1024, 2, 1] -> [b, 32, 2, 1]
        torch.nn.ConvTranspose2d(1024, 32, kernel_size=(n_bars - 1, 1), stride=(1, 1), padding=0),
        torch.nn.BatchNorm2d(32),
        torch.nn.ReLU(inplace=True),
        #[b, 32, 2, 1] -> [b, 32, 2]
        torch.nn.Flatten(start_dim=2),
    )


get_gen_block()(torch.randn(2, 32)).shape

torch.Size([2, 32, 2])

In [6]:
class GEN(torch.nn.Module):
    """MuseGAN-style generator producing a 4-track, 2-bar piano roll.

    Inputs (latent codes):
        chord  [b, 32]    - shared by all tracks, expanded per bar by ``gen_chord``
        style  [b, 32]    - shared by all tracks and bars
        melody [b, 4, 32] - one per track, expanded per bar by ``gen_melody[j]``
        groove [b, 4, 32] - one per track, identical for every bar

    Output: [b, 4, 2, 16, 84] (tracks, bars, steps per bar, pitches).
    """

    def __init__(self):
        super().__init__()

        self.gen_chord = get_gen_block()

        self.gen_melody = torch.nn.ModuleList(
            [get_gen_block() for _ in range(4)])

        self.gen_track = torch.nn.ModuleList(
            [get_gen_track() for _ in range(4)])

    def forward(self, chord, style, melody, groove):
        #[b, 32] -> [b, 32, 2]: one chord vector per bar
        out_chord = self.gen_chord(chord)

        #per track: [b, 32] -> [b, 32, 2]. The result does not depend on the bar,
        #so run each melody network once rather than once per (bar, track) pair
        #(this also avoids updating its BatchNorm statistics twice per call)
        out_melody = [net(melody[:, j]) for j, net in enumerate(self.gen_melody)]

        out_i = []
        for i in range(2):

            out_j = []
            for j in range(4):

                #[b, 32+32+32+32] -> [b, 128]
                out = torch.cat(
                    [out_chord[:, :, i], style, out_melody[j][:, :, i], groove[:, j]],
                    dim=1)

                #[b, 128] -> [b, 1, 1, 16, 84]
                out_j.append(self.gen_track[j](out))

            #4 x [b, 1, 1, 16, 84] -> [b, 4, 1, 16, 84]
            out_i.append(torch.cat(out_j, dim=1))

        #2 x [b, 4, 1, 16, 84] -> [b, 4, 2, 16, 84]
        return torch.cat(out_i, dim=2)


gen = GEN()

gen(torch.randn(2, 32), torch.randn(2, 32), torch.randn(2, 4, 32),
    torch.randn(2, 4, 32)).shape

torch.Size([2, 4, 2, 16, 84])

In [7]:
def get_cls():
    """Critic: score a [b, 4, 2, 16, 84] piano roll with one unbounded value per sample, [b, 1]."""
    n_bars = 2

    #(in_channels, out_channels, kernel_size, stride, padding) of each Conv3d,
    #each followed by LeakyReLU(0.3)
    conv_specs = [
        (4, 128, (n_bars, 1, 1), (1, 1, 1), 0),
        (128, 128, (n_bars - 1, 1, 1), (1, 1, 1), 0),
        (128, 128, (1, 1, 12), (1, 1, 12), 0),
        (128, 128, (1, 1, 7), (1, 1, 7), 0),
        (128, 128, (1, 2, 1), (1, 2, 1), 0),
        (128, 128, (1, 2, 1), (1, 2, 1), 0),
        (128, 2 * 128, (1, 4, 1), (1, 2, 1), (0, 1, 0)),
        (2 * 128, 4 * 128, (1, 3, 1), (1, 2, 1), (0, 1, 0)),
    ]

    layers = []
    for c_in, c_out, kernel, stride, pad in conv_specs:
        layers.append(torch.nn.Conv3d(c_in, c_out, kernel, stride, padding=pad))
        layers.append(torch.nn.LeakyReLU(0.3, inplace=True))

    layers += [
        #[b, 512, 1, 1, 1] -> [b, 512]
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 128, 1024),
        torch.nn.LeakyReLU(0.3, inplace=True),
        torch.nn.Linear(1024, 1),
    ]
    return torch.nn.Sequential(*layers)


cls = get_cls()

cls(torch.randn(2, 4, 2, 16, 84))

tensor([[0.0234],
        [0.0236]], grad_fn=<AddmmBackward0>)

In [8]:
def set_requires_grad(model, requires_grad):
    """Turn gradient tracking on or off for every parameter of ``model`` (returns None)."""
    model.requires_grad_(requires_grad)

def wasserstein(pred, label):
    """Wasserstein loss ``-mean(pred * label)``.

    With label = +1 this is -mean(pred); with label = -1 it is +mean(pred).
    """
    signed = label * pred
    return signed.mean().neg()


#one Adam optimizer per network, both with lr 1e-3 and betas (0.5, 0.9)
optimizer_cls = torch.optim.Adam(params=cls.parameters(), lr=1e-3, betas=(0.5, 0.9))
optimizer_gen = torch.optim.Adam(params=gen.parameters(), lr=1e-3, betas=(0.5, 0.9))

#use the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

gen.to(device)
cls.to(device)

device

'cuda'

In [9]:
def get_gradient_penalty(real, fake, model=None):
    """WGAN-GP gradient penalty for the critic.

    Samples random points on straight lines between ``real`` and ``fake`` and
    penalises the critic where the L2 norm of its input gradient differs from 1.

    Args:
        real: batch of real samples, e.g. [64, 4, 2, 16, 84].
        fake: batch of generated samples with the same shape as ``real``.
        model: critic to evaluate; defaults to the global ``cls``.

    Returns:
        Scalar tensor ``mean((1 - ||grad||_2) ** 2)``. It keeps its graph
        (``create_graph=True``), so it can be back-propagated into the critic.
    """
    if model is None:
        model = cls

    batch_size = real.size(0)

    #one mixing weight per sample, broadcast over all remaining dimensions
    r = torch.rand((batch_size, ) + (1, ) * (real.dim() - 1),
                   device=real.device,
                   dtype=real.dtype)

    #interpolated points as a fresh leaf: we need d(pred)/d(merge), and no
    #gradient should flow back into whatever produced `real` or `fake`
    merge = (r * real + (1 - r) * fake).detach().requires_grad_(True)

    #[b, ...] -> [b, 1]
    pred_merge = model(merge)

    grad = torch.autograd.grad(inputs=merge,
                               outputs=pred_merge,
                               grad_outputs=torch.ones_like(pred_merge),
                               create_graph=True,
                               retain_graph=True)[0]

    #[b, ...] -> [b, N] -> [b]
    grad = grad.reshape(batch_size, -1).norm(p=2, dim=1)

    #[b] -> scalar
    return (1 - grad).pow(2).mean()


#smoke test on random tensors: should give a scalar tensor with a grad_fn
get_gradient_penalty(real=torch.randn(64, 4, 2, 16, 84, device=device),
                     fake=torch.randn(64, 4, 2, 16, 84, device=device))

tensor(0.9996, device='cuda:0', grad_fn=<MeanBackward0>)

In [10]:
def train_cls():
    """Run one critic update and return the loss as a float.

    Critic loss = mean(D(fake)) - mean(D(real)) + 10 * gradient penalty.
    The generator is frozen and run under no_grad.
    """
    set_requires_grad(cls, True)
    set_requires_grad(gen, False)

    #one real batch; the batch size is taken from it rather than hard-coded,
    #so it always matches the loader's batch_size
    real = next(iter(loader)).to(device)
    batch_size = real.size(0)

    #a fake batch of the same size, without tracking generator gradients
    with torch.no_grad():
        chord = torch.randn(batch_size, 32, device=device)
        style = torch.randn(batch_size, 32, device=device)
        melody = torch.randn(batch_size, 4, 32, device=device)
        groove = torch.randn(batch_size, 4, 32, device=device)
        fake = gen(chord, style, melody, groove)

    #score both batches
    pred_fake = cls(fake)
    pred_real = cls(real)

    #weighted sum of the loss terms
    loss_fake = wasserstein(pred_fake, -torch.ones(batch_size, 1, device=device))
    loss_real = wasserstein(pred_real, torch.ones(batch_size, 1, device=device))
    loss_grad = get_gradient_penalty(real, fake)

    loss = loss_fake + loss_real + loss_grad * 10

    loss.backward()
    optimizer_cls.step()
    optimizer_cls.zero_grad()

    return loss.item()


train_cls()

9.99548053741455

In [11]:
def train_gen(batch_size=64):
    """Run one generator update and return the loss as a float.

    Generator loss = -mean(D(G(z))), with every latent input drawn from N(0, 1).
    The critic's parameters are frozen, so only the generator receives gradients.

    Args:
        batch_size: number of samples generated for this step (default 64).
    """
    set_requires_grad(cls, False)
    set_requires_grad(gen, True)

    chord = torch.randn(batch_size, 32, device=device)
    style = torch.randn(batch_size, 32, device=device)
    melody = torch.randn(batch_size, 4, 32, device=device)
    groove = torch.randn(batch_size, 4, 32, device=device)

    fake = gen(chord, style, melody, groove)
    fake_pred = cls(fake)

    loss = wasserstein(fake_pred, torch.ones(batch_size, 1, device=device))
    loss.backward()
    optimizer_gen.step()
    optimizer_gen.zero_grad()

    return loss.item()


train_gen()

0.015252873301506042

In [12]:
def train(epochs=20_000, log_every=2000):
    """Train the GAN: each step runs 5 critic updates, then 1 generator update.

    Every ``log_every`` steps, prints ``(step, critic loss, generator loss)``
    and plays one generated sample.

    Args:
        epochs: number of generator steps (default 20000).
        log_every: interval, in steps, between progress reports (default 2000).
    """
    for epoch in range(epochs):
        #several critic updates per generator update
        for _ in range(5):
            loss_cls = train_cls()

        loss_gen = train_gen()

        if epoch % log_every == 0:
            print(epoch, loss_cls, loss_gen)

            #preview sample. Latents come from N(0, 1) (torch.randn), the same
            #distribution used in training. Eval mode makes BatchNorm use its
            #running statistics instead of this tiny batch's, without updating them.
            was_training = gen.training
            gen.eval()
            try:
                with torch.no_grad():
                    chord = torch.randn(2, 32, device=device)
                    style = torch.randn(2, 32, device=device)
                    melody = torch.randn(2, 4, 32, device=device)
                    groove = torch.randn(2, 4, 32, device=device)

                    #[2, 4, 2, 16, 84]
                    pred = gen(chord, style, melody, groove)
            finally:
                gen.train(was_training)

            show(pred[0])


local_training = True

if local_training:
    train()

0 -88.02761840820312 365.5087890625


2000 -18.66831398010254 4.117001533508301


4000 -18.55518341064453 -1.6856918334960938


6000 -16.17269515991211 -4.4504075050354


8000 -12.522842407226562 -3.564983367919922


10000 -13.41166877746582 -3.798163890838623


12000 -9.745172500610352 -6.956740379333496


14000 -8.481831550598145 -3.7079086303710938


16000 -8.238203048706055 -1.697916030883789


18000 -7.125642776489258 -2.4775843620300293


In [13]:
from transformers import PreTrainedModel, PretrainedConfig


class Model(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper, used only to push/pull weights on the hub.

    It builds no new networks. It registers the notebook's global critic ``cls``
    and generator ``gen`` as submodules, so weights loaded by ``from_pretrained``
    end up in those same global objects.
    """

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        #Module.cpu() moves in place and returns the module itself, so the
        #global `cls` and `gen` are left on the CPU as a side effect
        self.cls = cls.cpu()
        self.gen = gen.cpu()


if local_training:
    #upload the trained weights to the hub. Model.__init__ moves the global
    #`cls` and `gen` to the CPU as a side effect.
    #Path.read_text opens and closes the token file (open().read() left it open)
    import pathlib
    Model(PretrainedConfig()).push_to_hub(
        repo_id='lansinuote/gen.7.musegan',
        use_auth_token=pathlib.Path('/root/hub_token.txt').read_text().strip())

pytorch_model.bin:   0%|          | 0.00/32.3M [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

In [14]:
#Load the trained model from the hub. Model.__init__ reuses the global `gen`/`cls`,
#so the cells defining them must have run first. .eval() makes BatchNorm use the
#stored running statistics.
gen = Model.from_pretrained('lansinuote/gen.7.musegan').gen.eval()
with torch.no_grad():
    #latent inputs drawn from N(0, 1) (torch.randn), the distribution used in
    #training; torch.rand would sample U[0, 1) instead
    chord = torch.randn(2, 32)
    style = torch.randn(2, 32)
    melody = torch.randn(2, 4, 32)
    groove = torch.randn(2, 4, 32)

    #[2, 4, 2, 16, 84]
    pred = gen(chord, style, melody, groove)
    show(pred[0])

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/32.3M [00:00<?, ?B/s]