In [1]:
import eGNN
import data_read
import protein_residues
import torch
import torch.nn as nn
import numpy as np
import math
import pickle as pkl
from tqdm import tqdm
from Bio.SeqUtils import seq1, seq3
from Bio.PDB import PDBIO, StructureBuilder
import gc
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device("cpu")
pdb_dir ='/mnt/rna01/nico/dynamics_project/learn/eGNN/data/dompdb/'
list_path = "/mnt/rna01/nico/dynamics_project/learn/eGNN/data/clean_pdb_id.txt"

In [2]:
data_std = 10 #The std of protein coords is determined to be ~10 as indicated by data_processing.ipynb

# New Utils

Define some functions that were not placed inside the Python scripts. 
generate_res_object => Creates a dictionary from the pdb_seq {str} and the coordinates of the backbone {torch.tensor}
generate_pdb        => Creates a PDB file from residues {dictionary} with name pdb_id {string}  

In [3]:
def generate_res_object(pdb_seq,coords):
    residues = []
    for i ,aa in enumerate(pdb_seq, start=0):
        if aa == "G": #If the amino acid is a glycine, do not add CB as there is no CB in glycine
            res = {"name":seq3(aa),"atoms":[("N", coords[i,0,:].tolist()), ("CA", coords[i,1,:].tolist()), ("C", coords[i,2,:].tolist())]}
        else:
            res = {"name":seq3(aa),"atoms":[("N", coords[i,0,:].tolist()), ("CA", coords[i,1,:].tolist()), ("C", coords[i,2,:].tolist()),("CB", coords[i,3,:].tolist())]}
        residues.append(res)
    return residues

In [4]:
def generate_pdb(pdb_id,residues):
    builder = StructureBuilder.StructureBuilder()
    
    # Create a structure object
    builder.init_structure("Predicted eGNN Backbone ")
    builder.init_model(0)
    builder.init_chain("A")  # Single chain "A"
    builder.init_seg(" ")
    
    for res_id, residue in enumerate(residues, start=1):
        builder.init_residue(residue["name"], " ", res_id, " ")
    
        # Add atoms to the residue
        for atom_name, coords in residue["atoms"]:
            builder.init_atom(atom_name, coords, 1.0, 1.0, " ", atom_name, res_id, atom_name[0])

    structure = builder.get_structure()
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_id+".pdb")
    

# Objects
Create model from multiple Pytorch objects. prot_eGNN is the final model which takes in 

In [5]:
class Diffusion(nn.Module):

    def __init__(self, T, b_initial, b_final, device):
        super().__init__()
        self.T = T #Maximum Timestep
        self.beta = torch.linspace(b_initial,b_final,T).to(device) #Define linear beta schedule
        
    def _CoMGaussNoise(self, x_t1, t1, t2): #Diffusion w.r.t., to the centre of mass as proposed by the E(3) diffusion https://arxiv.org/abs/2203.17003
        if t2 == 0:
            a_mul = torch.prod(1-self.beta[0])
        else:
            a_mul = torch.prod(1-self.beta[t1:t2+1])
        eps = torch.normal(mean=torch.zeros_like(x_t1),std=torch.ones_like(x_t1))
        x1_mean = torch.mean(x_t1.flatten(end_dim=-2),dim=-2, keepdim=True)[None,:]
        eps = eps - x1_mean 
        x_t2 = torch.sqrt(a_mul)*x_t1 + torch.sqrt(1-a_mul)*eps
        return x_t2, eps

    def _GaussNoise(self, x_t1, t1, t2): #Normal diffusion as proposed by the DDPM paper
        if t2 == 0:
            a_mul = torch.prod(1-self.beta[0])
        else:
            a_mul = torch.prod(1-self.beta[t1:t2+1])
        eps = torch.normal(mean=torch.zeros_like(x_t1),std=torch.ones_like(x_t1))
        x_t2 = torch.sqrt(a_mul)*x_t1 + torch.sqrt(1-a_mul)*eps
        #Same output has the same shape as the input
        return x_t2, eps
        
    def forward(self, x, h, t2):
        x_perturbed, x_eps = self._GaussNoise(x, 0, t2)
        h_perturbed, h_eps = self._GaussNoise(h,0,t2)
        return x_perturbed, x_eps, h_perturbed, h_eps

In [6]:
def input_generation(pdb_path,edge_device="cuda"): #Takes a PDB path and generate a torch tensor of the backbone coordinates (frames), edges defined by the node indexes (edge_id) and the 1 letter representation of each amino acid as defined by the 
                                                   # IUPAC convention https://iupac.qmul.ac.uk/AminoAcid/A2021.html
    frames, seq = data_read.get_backbone(pdb_path)
    frames = torch.from_numpy(frames[0]); seq = seq[0]
    frames = frames.to(torch.float32)
    n = len(seq)
    seq_id = data_read.encode(seq[0])
    
    
    #Assumes fully connected graph
    row =torch.arange(0,n).repeat_interleave(n).to(device)
    col =torch.arange(0,n).repeat(n).to(device)
    edge_id=(row,col)

    return frames, seq, edge_id

In [7]:
#Wrapper class that combines an eGNN and a Diffusion module
class prot_eGNN(nn.Module):
    
    def __init__(self, process_device, gnn_device, embed_dim, sc_dim, in_node_nf, hidden_nf, out_node_nf,
                 T, b_initial=0.0001, b_final=0.02, 
                 in_edge_nf=0, act_fn=nn.SiLU(), n_layers=6, residual=True, attention=True, normalize=False, tanh=False                 
                ):
        super().__init__()

        self.sc_dim = sc_dim
        self.process_device = process_device
        self.gnn_device = gnn_device
        self.T = T
        
        self.embedding = nn.Embedding(20,embed_dim)
        self.diffusion = Diffusion(T, b_initial=0.0001, b_final=0.02, device=self.process_device) 
        self.EGNN = eGNN.EGNN(in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, 
                              device=self.gnn_device , act_fn=nn.SiLU(), n_layers=n_layers, residual=residual, attention=attention, normalize=normalize, tanh=tanh)

    def _sc_embed(self, t): #sc embedding of the noising timestep as proposed by the DDPM (https://arxiv.org/pdf/2006.11239)
        half_dim = self.sc_dim//2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim,) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

    def sc_pos_embed(self, h): # Sequence position embedding as used in ESMFold (https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1)
        t = torch.arange(0,h.shape[0])+1
        half_dim = h.shape[-1]//2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim,) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        h += emb
        return h
        
    def sample(self, seq_id): #Sampling process as defined by the DDPM paper
        #seq_id is a 1D tensor
        x_t = torch.randn((seq_id.shape[0], 4, 3)).to(self.gnn_device) #N x 4 x 3
        
        #This remains constant, no need to recompute
        n_aa = seq_id.shape[0]
        seq_id=seq_id.to(self.gnn_device)
        h = self.embedding(seq_id)
        h = h.to(self.process_device)
        h = self.sc_pos_embed(h) 

        #Generate Edges
        row =torch.arange(0,n_aa).repeat_interleave(n_aa).to(self.gnn_device)
        col =torch.arange(0,n_aa).repeat(n_aa).to(self.gnn_device)
        edges_id=(row,col)
        
        for t in range(self.T-1, -1, -1):  
            if t > 0:
                z = torch.normal(mean=torch.zeros_like(x_t),std=torch.ones_like(x_t))
            else:
                z = torch.zeros_like(x_t)
            sc_emb = self._sc_embed(torch.tensor([t+1])).repeat(n_aa, 1)
            h_t = torch.cat((h,sc_emb),dim=-1)
            h_t = h_t.to(self.gnn_device)
            pred_h_eps, pred_x_eps= self.EGNN(h_t, x_t, edges_id, edge_attr=None)
            if t != 0:
                a_mul = torch.prod(1-self.diffusion.beta[0:t+1])
                a_mul_1 = torch.prod(1-self.diffusion.beta[0:t])
            else:
                a_mul = torch.prod(1-self.diffusion.beta[t])
                a_mul_1 = 0
            a = 1-self.diffusion.beta[t]
            #N.B. Vars = mul(Bt) for Xo ~ N(0,I) if Xo is a predetermined point the defined vars is prefered. See p.g.,3 of the DDPM paper
            std = ((1-a_mul_1)/(1-a_mul))*self.diffusion.beta[t]
            std= torch.sqrt(std)
            x_t = (a**(-0.5))*(x_t- ((1-a)/(torch.sqrt(1-a_mul)))*pred_x_eps)+std*z
        return x_t #Multiply the output by the global std
        
    def forward(self, x, seq_id, edges_id): #Takes the backbone coordinates, sequence id, edges, and adds noise prior to passing it to the eGNN
        assert x.shape[0]==seq_id.shape[0], "Sequence length and coordinate first dimension is not of the same shape!"

        n_aa = seq_id.shape[0]
        seq_id=seq_id.to(self.gnn_device)
        h = self.embedding(seq_id)
        h = h.to(self.process_device)
        h = self.sc_pos_embed(h)
        
        t2 = torch.randint(0, high=self.T, size=(1,))
        sc_emb = self._sc_embed(t2+1).repeat(n_aa, 1)
        x_perturbed, x_eps, h_perturbed, h_eps = self.diffusion(x,h,t2)
        x_perturbed=x_perturbed.to(self.gnn_device)
        h = torch.cat((h,sc_emb),dim=-1)
        h = h.to(self.gnn_device)
        #h_perturbed=h_perturbed.to(self.gnn_device)
        pred_h_eps, pred_x_eps= self.EGNN(h, x_perturbed, edges_id, edge_attr=None)
        return pred_h_eps, h_eps, pred_x_eps, x_eps


In [8]:
#Params
#Num of clean proteins 31065
steps = 20000
batch_size = 500
with open(list_path) as file:
    lines = [line.rstrip() for line in file]
np.random.seed(seed=17)
np.random.shuffle(lines)
train_list = np.array(lines[:-len(lines)//5])
val_list = np.array(lines[-len(lines)//5:])
np.savetxt("train_list.csv", train_list, delimiter=",", fmt='%s')
np.savetxt("val_list.csv", val_list, delimiter=",", fmt='%s')
#Instantiate Models 
embed_dim = 64
sc_dim = 32
in_node_nf = embed_dim + sc_dim
hidden_nf =128
out_node_nf = 64

model = prot_eGNN("cpu", "cuda" , embed_dim, sc_dim, in_node_nf, hidden_nf, out_node_nf,
                 800, b_initial=0.0001, b_final=0.02, 
                 in_edge_nf=0, act_fn=nn.SiLU(), n_layers=5, residual=True, attention=True, normalize=False, tanh=False)
model=model.to("cuda")

In [9]:
#Define loss here
denoising_loss = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.01, lr=0.0001)

Attempted different training loss as when training naively we observed some training loss instability. The training loss instability is solved after data normalization 

In [None]:
#Training Loop 1 (Training by batches)
g_pdb_dir="/mnt/rna01/nico/dynamics_project/learn/eGNN/generated_pdb/"
model.train()
agg_loss = torch.tensor([0.],dtype=torch.float32).to("cuda")
for i in tqdm(range(steps)):
    pdb_path = pdb_dir+np.random.choice(train_list)
    try:
        frames,seq,edges=input_generation(pdb_path)
    except:
        continue
    CoM = torch.mean(frames.flatten(end_dim=-2),dim=-2, keepdim=True)
    frames = frames - CoM[None,:]
    frames = frames/data_std
    i_seq = data_read.encode(seq)
    
    pred_h_eps, h_eps, pred_x_eps, x_eps=model(frames, i_seq, edges)

    
    x_eps = x_eps.to("cuda")
    loss = denoising_loss(pred_x_eps,x_eps)
    agg_loss +=loss
    if (i+1)%500 == 0: #Print loss
        print(f"{i+1} Step : {loss}")
        print(f"{i+1} Step : {agg_loss/500}")
        print(type(agg_loss))
        agg_loss = torch.tensor([0.],dtype=torch.float32).to("cuda")
        torch.save({
            'epoch': i+1,
            'model_state_dict': model.state_dict(),
            'step_loss': agg_loss
            }, f'./model_weights/prot_eGNN_{i+1}')
        model.eval()
        if i>9999: #Generate some structures after training and save them as PDB files
            for j in val_list[:30]:
                v_frames,v_seq,v_edges=input_generation(pdb_dir+j)
                del v_frames
                del v_edges
                gc.collect()
                v_i_seq = data_read.encode(v_seq)
                with torch.no_grad():
                    coords=model.sample(v_i_seq)
                    coords=coords*data_std
                res = generate_res_object(v_seq,coords)
                generate_pdb(g_pdb_dir+f"pred_{j}_model_{i+1}",res)  
        model.train()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  1%|█                                                                              | 257/20000 [00:10<13:39, 24.09it/s]

In [10]:
# Training Loop 2 (Training on each step)
g_pdb_dir="/mnt/rna01/nico/dynamics_project/learn/eGNN/generated_pdb/"
model.train()
agg_loss = torch.tensor([0.],dtype=torch.float32).to("cuda")
for i in tqdm(range(steps)):
    pdb_path = pdb_dir+np.random.choice(train_list)
    try:
        frames,seq,edges=input_generation(pdb_path)
    except:
        continue
    if len(seq)>400:continue
    CoM = torch.mean(frames.flatten(end_dim=-2),dim=-2, keepdim=True)
    frames = frames - CoM[None,:]
    frames = frames/data_std
    i_seq = data_read.encode(seq)
    
    pred_h_eps, h_eps, pred_x_eps, x_eps=model(frames, i_seq, edges)

    
    x_eps = x_eps.to("cuda")
    loss = denoising_loss(pred_x_eps,x_eps)
    loss.backward()
    agg_loss +=loss
    if (i+1)%batch_size ==0:
        optimizer.step()
        optimizer.zero_grad()
    if (i+1)%500 == 0:
        print(f"{i+1} Step : {loss}")
        print(f"{i+1} Step : {agg_loss/500}")
        print(type(agg_loss))
        torch.save({
            'epoch': i+1,
            'model_state_dict': model.state_dict(),
            'step_loss': agg_loss[0]
            }, f'./model_weights/prot_eGNN_{i+1}')
        agg_loss = torch.tensor([0.],dtype=torch.float32).to("cuda")
        if i>10000:
            model.eval()
            for j in val_list[:30]:
                v_frames,v_seq,v_edges=input_generation(pdb_dir+j)
                del v_frames
                del v_edges
                gc.collect()
                v_i_seq = data_read.encode(v_seq)
                with torch.no_grad():
                    coords=model.sample(v_i_seq)
                    coords=coords*data_std
                res = generate_res_object(v_seq,coords)
                generate_pdb(g_pdb_dir+f"pred_{j}_model_{i+1}",res)  
            model.train()

  2%|█▉                                                                             | 498/20000 [00:17<09:24, 34.53it/s]

500 Step : 0.5745052695274353
500 Step : tensor([0.5515], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


  5%|███▉                                                                          | 1003/20000 [00:35<11:56, 26.53it/s]

1000 Step : 1.7658617496490479
1000 Step : tensor([0.4631], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


  8%|█████▊                                                                        | 1503/20000 [00:51<08:52, 34.75it/s]

1500 Step : 0.1382541060447693
1500 Step : tensor([0.4537], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 10%|███████▊                                                                      | 2003/20000 [01:08<11:30, 26.06it/s]

2000 Step : 0.005201550666242838
2000 Step : tensor([0.4711], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 12%|█████████▊                                                                    | 2500/20000 [01:25<09:21, 31.16it/s]

2500 Step : 0.17913168668746948
2500 Step : tensor([0.4584], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 15%|███████████▋                                                                  | 3004/20000 [01:42<08:33, 33.09it/s]

3000 Step : 0.4811536371707916
3000 Step : tensor([0.4154], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 18%|█████████████▋                                                                | 3504/20000 [01:58<09:21, 29.38it/s]

3500 Step : 0.47320646047592163
3500 Step : tensor([0.4075], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 20%|███████████████▌                                                              | 4005/20000 [02:16<10:39, 25.02it/s]

4000 Step : 0.008586352691054344
4000 Step : tensor([0.4290], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 23%|█████████████████▌                                                            | 4505/20000 [02:32<08:10, 31.59it/s]

4500 Step : 0.003172078635543585
4500 Step : tensor([0.4180], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 25%|███████████████████▌                                                          | 5003/20000 [02:50<10:29, 23.82it/s]

5000 Step : 0.031799621880054474
5000 Step : tensor([0.3630], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 28%|█████████████████████▍                                                        | 5502/20000 [03:07<08:19, 29.05it/s]

5500 Step : 0.0016593351028859615
5500 Step : tensor([0.4137], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 30%|███████████████████████▍                                                      | 6003/20000 [03:24<07:52, 29.61it/s]

6000 Step : 0.0015464074676856399
6000 Step : tensor([0.3718], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 33%|█████████████████████████▎                                                    | 6502/20000 [03:42<08:01, 28.02it/s]

6500 Step : 1.1315107345581055
6500 Step : tensor([0.3704], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 35%|███████████████████████████▎                                                  | 7004/20000 [03:59<08:44, 24.77it/s]

7000 Step : 0.0013609816087409854
7000 Step : tensor([0.3818], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 38%|█████████████████████████████▎                                                | 7501/20000 [04:15<08:15, 25.22it/s]

7500 Step : 0.04701844975352287
7500 Step : tensor([0.3687], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 40%|███████████████████████████████▏                                              | 8002/20000 [04:32<07:03, 28.31it/s]

8000 Step : 0.03780105337500572
8000 Step : tensor([0.5123], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 43%|█████████████████████████████████▏                                            | 8503/20000 [04:48<06:28, 29.61it/s]

8500 Step : 0.07415009289979935
8500 Step : tensor([0.3835], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 45%|███████████████████████████████████                                           | 9006/20000 [05:05<07:18, 25.09it/s]

9000 Step : 0.03502056002616882
9000 Step : tensor([0.3481], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 48%|█████████████████████████████████████                                         | 9503/20000 [05:22<05:57, 29.39it/s]

9500 Step : 1.7008038759231567
9500 Step : tensor([0.3604], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 50%|██████████████████████████████████████▌                                      | 10003/20000 [05:38<04:54, 33.96it/s]

10000 Step : 0.1585865169763565
10000 Step : tensor([0.3924], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 52%|████████████████████████████████████████▍                                    | 10497/20000 [05:55<04:42, 33.64it/s]

10500 Step : 0.25409287214279175
10500 Step : tensor([0.5464], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 52%|████████████████████████████████████████▍                                    | 10499/20000 [07:16<06:35, 24.04it/s]


KeyboardInterrupt: 

# Debugging Attempt

In [13]:
#Params
#Num of clean proteins 31065
steps = 1001
batch_size = 500
with open(list_path) as file:
    lines = [line.rstrip() for line in file]
np.random.seed(seed=17)
np.random.shuffle(lines)
train_list = np.array(lines[:-len(lines)//5])
val_list = np.array(lines[-len(lines)//5:])
np.savetxt("train_list.csv", train_list, delimiter=",", fmt='%s')
np.savetxt("val_list.csv", val_list, delimiter=",", fmt='%s')
#Instantiate Models 
embed_dim = 64
sc_dim = 32
in_node_nf = embed_dim + sc_dim
hidden_nf =128
out_node_nf = 64

model = prot_eGNN("cpu", "cuda" , embed_dim, sc_dim, in_node_nf, hidden_nf, out_node_nf,
                 800, b_initial=0.0001, b_final=0.02, 
                 in_edge_nf=0, act_fn=nn.SiLU(), n_layers=5, residual=True, attention=True, normalize=False, tanh=False)
model=model.to("cuda")

In [None]:
#Define loss here
denoising_loss = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.01, lr=0.0001)

Tried to overfit on one model to see if the model learns correctly. Similar loss is observed as when training on a large dataset implying that there is some bugs with the model. 

In [None]:
# Overfitting Attempt
g_pdb_dir="/mnt/rna01/nico/dynamics_project/learn/eGNN/generated_pdb/"
model.train()
agg_loss = torch.tensor([0.],dtype=torch.float32).to("cuda")
train_list = [train_list[0]]
print(train_list)
for i in tqdm(range(steps)):
    pdb_path = pdb_dir+np.random.choice(train_list)
    try:
        frames,seq,edges=input_generation(pdb_path)
    except:
        continue
    if len(seq)>400:continue
    CoM = torch.mean(frames.flatten(end_dim=-2),dim=-2, keepdim=True)
    frames = frames - CoM[None,:]
    frames = frames/data_std
    i_seq = data_read.encode(seq)
    
    pred_h_eps, h_eps, pred_x_eps, x_eps=model(frames, i_seq, edges)

    
    x_eps = x_eps.to("cuda")
    loss = denoising_loss(pred_x_eps,x_eps)
    loss.backward()
    agg_loss +=loss
    if (i+1)%batch_size ==0:
        optimizer.step()
        optimizer.zero_grad()
    if (i+1)%500 == 0:
        print(f"{i+1} Step : {loss}")
        print(f"{i+1} Step : {agg_loss/500}")
        print(type(agg_loss))
        torch.save({
            'epoch': i+1,
            'model_state_dict': model.state_dict(),
            'step_loss': agg_loss[0]
            }, f'./model_weights/prot_eGNN_{i+1}')
        agg_loss = torch.tensor([0.],dtype=torch.float32).to("cuda")
        if i>900:
            model.eval()
            for j in train_list:
                v_frames,v_seq,v_edges=input_generation(pdb_dir+j)
                del v_frames
                del v_edges
                gc.collect()
                v_i_seq = data_read.encode(v_seq)
                with torch.no_grad():
                    coords=model.sample(v_i_seq)
                    coords=coords*data_std
                res = generate_res_object(v_seq,coords)
                generate_pdb(g_pdb_dir+f"pred_{j}_model_{i+1}",res)  
            model.train()

['4h7wA00']


 50%|████████████████████████████████████████▎                                       | 504/1001 [00:19<00:21, 23.21it/s]

500 Step : 0.010035381652414799
500 Step : tensor([0.4285], device='cuda:0', grad_fn=<DivBackward0>)
<class 'torch.Tensor'>


 68%|██████████████████████████████████████████████████████                          | 676/1001 [00:26<00:11, 27.39it/s]