In [1]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [2]:
# download the names.txt file from github
!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt

--2023-03-28 15:45:34--  https://raw.githubusercontent.com/karpathy/makemore/master/names.txt
正在解析主机 raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133
正在连接 raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... 已连接。
无法建立 SSL 连接。


In [3]:

class CausalConv1D(nn.Module):
    def __init__(self,inChan,outChan,dilation=1,nonlinearity='linear'):
        super().__init__()
        K=2
        # self.receive_field=K+(K-1)*(dilation-1)

        self.conv=nn.Conv1d(inChan,outChan,K,dilation=dilation,bias=False)

        nn.init.kaiming_normal_(self.conv.weight,  nonlinearity=nonlinearity)

    def forward(self, x):
        # x=F.pad(x, (self.receive_field-1,0),mode='constant', value=0)
        self.out=self.conv(x)

        return  self.out

class Block(nn.Module):
    def __init__(self,inChan,dilation=1):
        super(Block, self).__init__()

        self.causul_conv=CausalConv1D(inChan,inChan,dilation,nonlinearity='tanh')  #for tanh
        self.causul_conv2 = CausalConv1D(inChan, inChan, dilation) #for sigma

        self.tanh=nn.Tanh()
        self.sigmoid=nn.Sigmoid()


        self.output_conv=nn.Conv1d(inChan,inChan,1)
        self.skip_conv = nn.Conv1d(inChan, inChan, 1)


    def forward(self, x,debug=False):
        h=self.causul_conv(x)
        h1=self.causul_conv2(x)

        th=self.tanh(h)
        # sh=self.sigmoid(h1)
        w=th*h1

        o=self.output_conv(w)+x[:,:,-h.shape[2]:]
        s=self.skip_conv(w)
        return o,s



class WaveNet(nn.Module):
    def __init__(self,classes,resChan,stacks,layers):
        super(WaveNet, self).__init__()

        dilations=[2**l for l in range(layers)]*stacks

        self.receive_filed=1
        for i,d in enumerate(dilations):
            self.receive_filed+=d


        self.emb_layer=nn.Embedding(classes,resChan)
        # bug,没有加这个ModuleList,这些block对于的参数网络
        self.layers=nn.ModuleList([Block(resChan,d) for d in dilations])


        self.post_layer=nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(resChan,resChan,1),
            nn.ReLU(),
            nn.Conv1d(resChan, classes, 1),
        )
    def forward(self,x):
        # x.shape =(B,T,classes)
        B,T=x.shape
        remain=T-self.receive_filed+1

        emb=self.emb_layer(x).transpose(1,2)

        h=emb

        output=0
        for layer in self.layers:
            h,s=layer(h)
            output+=s[:,:,-remain:]

        self.skipout=output
        # (B,classes,T)
        logit=self.post_layer(output)

        return torch.squeeze(logit)


In [4]:
class DataSet():
    def __init__(self):
        self.words = open('names.txt').read().splitlines()
        random.shuffle(self.words)

        self.chs = sorted(set(''.join(self.words)))
        self.stoi = {c: i + 1 for i, c in enumerate(self.chs)}
        self.stoi['.'] = 0
        self.itos = {i: c for c, i in self.stoi.items()}
        self.VSIZE=len(self.itos)
    def index2Str(self,l):
        return "".join([self.itos[id] for id in l])
    def _build_dataset(self,words,contextSize):
        Xs,Ys=[],[]
        for word in words:
            context = [0] * contextSize
            for ch in word+".":
                idx=self.stoi[ch]
                Xs.append(context)
                Ys.append(idx)
                context=context[1:]+[idx]
        return torch.tensor(Xs),torch.tensor(Ys)

    def getData(self,tag,contextSize):
        n1 = int(len(self.words) * 0.8)
        n2 = int(len(self.words) * 0.9)

        if tag=='train':
            X,Y=self._build_dataset(self.words[0:n1],contextSize)
        elif tag=='val':
            X, Y = self._build_dataset(self.words[n1:n2],contextSize)
        elif tag=='test':
            X, Y = self._build_dataset(self.words[n2:],contextSize)

        print(f"{tag}: X: {X.shape}, Y: {Y.shape}")

        return X,Y

In [64]:
hparams={
        "contextSize":8,
        "embSize":16,
        "hiddenSize":100,
        "steps":20000,
        "batch_size":32,
        "Wgain":5/3,
        "softmax_gain":0.01
}

random.seed(42)
torch.manual_seed(2147483647)


ds=DataSet()
Xtr,Ytr=ds.getData('train',hparams["contextSize"])
Xdev,Ydev=ds.getData('val',hparams["contextSize"])
Xtest,Ytest=ds.getData('test',hparams["contextSize"])


embSize=hparams["embSize"]
hiddenSize=hparams["hiddenSize"]
Wgain=hparams["Wgain"]
softmax_gain=hparams["softmax_gain"]

model=WaveNet(ds.VSIZE,hiddenSize,1,3)

def zeros_grad(model):
    for p in model.parameters():
       p.grad=None
model.load_state_dict(torch.load("wavenet_model_colab.pth",map_location=torch.device('cpu')),strict=False)

train: X: torch.Size([182618, 8]), Y: torch.Size([182618])
val: X: torch.Size([22644, 8]), Y: torch.Size([22644])
test: X: torch.Size([22892, 8]), Y: torch.Size([22892])


<All keys matched successfully>

In [65]:
ncount=sum([p.nelement() for p in model.parameters()])
print(f"# {ncount}")

print("--------------training---------------")
print(f"hparams:{hparams}")
print()

model.train()

# 196127
--------------training---------------
hparams:{'contextSize': 8, 'embSize': 16, 'hiddenSize': 100, 'steps': 20000, 'batch_size': 32, 'Wgain': 1.6666666666666667, 'softmax_gain': 0.01}



WaveNet(
  (emb_layer): Embedding(27, 100)
  (layers): ModuleList(
    (0): Block(
      (causul_conv): CausalConv1D(
        (conv): Conv1d(100, 100, kernel_size=(2,), stride=(1,), bias=False)
      )
      (causul_conv2): CausalConv1D(
        (conv): Conv1d(100, 100, kernel_size=(2,), stride=(1,), bias=False)
      )
      (tanh): Tanh()
      (sigmoid): Sigmoid()
      (output_conv): Conv1d(100, 100, kernel_size=(1,), stride=(1,))
      (skip_conv): Conv1d(100, 100, kernel_size=(1,), stride=(1,))
    )
    (1): Block(
      (causul_conv): CausalConv1D(
        (conv): Conv1d(100, 100, kernel_size=(2,), stride=(1,), dilation=(2,), bias=False)
      )
      (causul_conv2): CausalConv1D(
        (conv): Conv1d(100, 100, kernel_size=(2,), stride=(1,), dilation=(2,), bias=False)
      )
      (tanh): Tanh()
      (sigmoid): Sigmoid()
      (output_conv): Conv1d(100, 100, kernel_size=(1,), stride=(1,))
      (skip_conv): Conv1d(100, 100, kernel_size=(1,), stride=(1,))
    )
    (2): Bloc

In [66]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=model.to(device)
@torch.no_grad()
def splitLoss(split):
    d={
        'train':(Xtr,Ytr),
        'dev':(Xdev,Ydev),
        'tetst':(Xtest,Ytest)
    }
    xs,ys=d[split]
    if split=="train":
        model.train()
    else:
        model.eval()
    loss=F.cross_entropy(model(xs.to(device)),ys.to(device)).item()
    return loss

In [91]:
grads=[]
def zeros_grads_buffer():
    grads.clear()
    for p in model.parameters():
        grads.append(torch.zeros_like(p.data))

zeros_grads_buffer()
batchInterval=200
for step in range(500):
    idx=torch.randint(0,Xtr.shape[0],(hparams["batch_size"],),generator=None)
    xs_batch=Xtr[idx].to(device)
    ys_batch=Ytr[idx].to(device)


    logit=model(xs_batch)
    nll=F.cross_entropy(logit,ys_batch)

    # model.zero_grads()
    zeros_grad(model)
    # model.skipout.retain_grad()
    nll.backward()

    lr = 0.001  if step < 100000 else 0.01

    # 积累grads
    for i,p in enumerate(model.parameters()):
       if p.grad == None:
           continue
       grads[i]+=p.grad

    if step %batchInterval==batchInterval-1:
        print('grad summary')
        for i,p in enumerate(model.parameters()):
            if p.grad == None:
                continue
            print(grads[i].abs().mean()/batchInterval)
            p.data-=lr*grads[i]
        zeros_grads_buffer()

    # print(nll.item(),(model.skipout.grad/model.skipout).abs().mean())
    # print(nll.item(),(pre-model.skipout).abs().mean())
    pre=model.skipout
    if( step%20==0):
        train_loss=nll.item()
        dev_loss=splitLoss('dev')
        # print(f'{step:7d}/{hparams["steps"]:7d} train_loss {train_loss:.4f},dev_loss {dev_loss:.4f}')
        model.train()
print("--------------training end---------------")

grad summary
tensor(0.0003)
tensor(0.0004)
tensor(0.0006)
tensor(0.0009)
tensor(0.0008)
tensor(0.0003)
tensor(0.0002)
tensor(0.0004)
tensor(0.0006)
tensor(0.0006)
tensor(0.0003)
tensor(0.0005)
tensor(0.0002)
tensor(0.0003)
tensor(0.0004)
tensor(0.0006)
tensor(0.0002)
tensor(0.0006)
tensor(0.0005)
tensor(0.0009)
tensor(0.0014)
grad summary
tensor(0.0003)
tensor(0.0003)
tensor(0.0006)
tensor(0.0009)
tensor(0.0006)
tensor(0.0003)
tensor(0.0002)
tensor(0.0004)
tensor(0.0005)
tensor(0.0006)
tensor(0.0002)
tensor(0.0005)
tensor(0.0002)
tensor(0.0003)
tensor(0.0004)
tensor(0.0006)
tensor(0.0002)
tensor(0.0005)
tensor(0.0004)
tensor(0.0008)
tensor(0.0009)
--------------training end---------------


In [71]:
final_train_loss=splitLoss('train')
final_dev_loss=splitLoss('dev')
print(f"train loss {final_train_loss:.4f}, dev loss {final_dev_loss:.4f}")
torch.save(model.state_dict(), "wavenet_model.pth")

train loss 1.8206, dev loss 2.0102


In [72]:
grads=[]
for p in model.parameters():
    grads.append(torch.zeros_like(p.data))
#梯度真的为0了吗？
for step in range(5300):
    idx=torch.randint(0,Xtr.shape[0],(hparams["batch_size"],),generator=None)
    xs_batch=Xtr[idx].to(device)
    ys_batch=Ytr[idx].to(device)


    logit=model(xs_batch)
    nll=F.cross_entropy(logit,ys_batch)

    # model.zero_grads()
    zeros_grad(model)
    nll.backward()


    for j,p in enumerate(model.parameters()):
       if p.grad == None:
           continue
       grads[j]+=p.grad
print("--------------training end---------------")

--------------training end---------------


In [79]:
for g in grads:
    print((g.abs().mean()))

tensor(0.3734)
tensor(0.4579)
tensor(0.7278)
tensor(1.1419)
tensor(0.9706)
tensor(0.3301)
tensor(0.2190)
tensor(0.5266)
tensor(0.6734)
tensor(0.7403)
tensor(0.3161)
tensor(0.6566)
tensor(0.2190)
tensor(0.3692)
tensor(0.5054)
tensor(0.)
tensor(0.)
tensor(0.7178)
tensor(0.2190)
tensor(0.7374)
tensor(0.7063)
tensor(1.1287)
tensor(2.0822)


In [48]:
grads

[]