The losses remain the same #8

Closed
18972441546 opened this issue May 9, 2022 · 5 comments

@18972441546

```python
import torch

class CNNCRF(torch.nn.Module):
    """
    Simple CNN-CRF model
    """
    def __init__(self, cnn, crf):
        super().__init__()
        self.cnn = cnn
        self.crf = crf

    def forward(self, x):
        """
        x is a batch of input images
        """
        logits = self.cnn(x)
        logits = self.crf(x, logits)
        return logits

# Create a CNN-CRF model from given cnn and crf
# (cnn and crf are assumed to be defined beforehand)
# This is a PyTorch module that can be used in a usual way
model = CNNCRF(cnn, crf)
```

First I trained a UNet and saved the model. Then I loaded the trained UNet and trained the UNet and the CRF together. The loss is stuck at 0.693147. Do you have any suggestions?
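For reference, 0.693147 is ln 2. The post does not show `criterion`, so this assumes it is binary cross-entropy on logits, such as `torch.nn.BCEWithLogitsLoss`. Under that assumption, ln 2 is exactly the loss you get when every output logit is 0. A logit of 0 means a predicted probability of 0.5 per pixel, and the loss is the same whatever the target.

The sketch below checks this with the standard library only. `bce_with_logits` is a hypothetical helper written for illustration, not part of PyTorch or the CRF package.

```python
import math

def bce_with_logits(z: float, y: float) -> float:
    """Binary cross-entropy for one logit z and target y in [0, 1] (hypothetical helper),
    using the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# A logit of 0 means a predicted probability of sigmoid(0) = 0.5.
for target in (0.0, 1.0):
    print(target, round(bce_with_logits(0.0, target), 6))  # 0.693147 for both targets

print(round(math.log(2), 6))  # 0.693147
```

A loss pinned at this value therefore suggests that the network output reaching the loss is constant at zero. It does not by itself say whether the UNet or the CRF layer is responsible.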

@netw0rkf10w
Owner

Could you tell me how you instantiated the CRF? Thanks.

@18972441546
Author

My code is shown below. This version trains directly, but the loss is still constant.

```python
import CRF
import torch
from tqdm import tqdm
from unet_model import *
# ISBI_Loader is the user's dataset class; its import is not shown in the post.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

params = CRF.FrankWolfeParams(scheme='fixed',  # constant stepsize
                              stepsize=1.0,
                              regularizer='l2',
                              lambda_=1.0,  # regularization weight
                              lambda_learnable=False,
                              x0_weight=0.5,  # useful for training, set to 0 if inference only
                              x0_weight_learnable=False)

crf = CRF.DenseGaussianCRF(classes=1,
                           alpha=160,
                           beta=0.05,
                           gamma=3.0,
                           spatial_weight=1.0,
                           bilateral_weight=1.0,
                           compatibility=1.0,
                           init='potts',
                           solver='fw',
                           iterations=5,
                           params=params)


class CNNCRF(torch.nn.Module):
    """
    Simple CNN-CRF model
    """
    def __init__(self, cnn, crf):
        super().__init__()
        self.cnn = cnn
        self.crf = crf

    def forward(self, x):
        """
        x is a batch of input images
        """
        logits = self.cnn(x)  # output tensor of the CNN
        logits = self.crf(x, logits)
        return logits


# Create a CNN-CRF model from given cnn and crf
# This is a PyTorch module that can be used in a usual way
cnn = UNet()
UnetCrfs = CNNCRF(cnn, crf).to(device)

if __name__ == '__main__':
    cnn = UNet(1, 1)  # 1 input channel, 1 output channel
    model = CNNCRF(cnn, crf).to(device)
    data_path = './data'
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=1,
                                               shuffle=False)
    # epochs, criterion, optimizer and pbar are not shown in the post;
    # the definitions below are assumptions.
    epochs = 40
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-5,
                                    weight_decay=1e-8, momentum=0.9)
    pbar = tqdm(total=epochs * len(train_loader))
    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        for image, label in train_loader:
            optimizer.zero_grad()
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = model(image)
            loss = criterion(pred, label)
            # print('{}/{}:Loss/train'.format(epoch + 1, epochs), loss.item())
            if loss.item() < best_loss:
                best_loss = loss.item()  # store a float, not the graph-carrying tensor
                torch.save(model.state_dict(), 'best_model.pth')
            loss.backward()
            optimizer.step()
            pbar.update(1)
```

@netw0rkf10w
Owner

Are you sure that training is successful without the CRF? Could you try replacing the line `UnetCrfs = CNNCRF(cnn, crf).to(device)` with `UnetCrfs = cnn.to(device)` and see what happens?
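When you compare the runs with and without the CRF, one quick check is whether the logged losses ever move away from ln 2 ≈ 0.693147.

This is a minimal sketch under stated assumptions:

- `stuck_at_ln2` is a hypothetical helper.
- The tolerance `tol` is an arbitrary choice.
- The loss lists are made-up values for illustration only.

```python
import math

LN2 = math.log(2)

def stuck_at_ln2(losses, tol=1e-4):
    """Hypothetical helper: True if every recorded loss is within tol of ln 2,
    i.e. the model never moved away from predicting probability 0.5."""
    return len(losses) > 0 and all(abs(loss - LN2) < tol for loss in losses)

# Made-up loss histories, for illustration only.
plateaued = [0.693147, 0.693147, 0.693147]
decreasing = [0.69, 0.41, 0.27]
print(stuck_at_ln2(plateaued))   # True
print(stuck_at_ln2(decreasing))  # False
```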

@18972441546
Author

Thanks, it is OK.

@netw0rkf10w
Owner

Great. Do not hesitate to let me know if you encounter any issues.
