I wonder if I could get some help with my own RGB input #11

Closed
GustavoCamargoRL opened this issue Aug 12, 2019 · 15 comments

@GustavoCamargoRL

I'm trying to test with my own inputs, but I'm not quite sure how to do it.
I thought it was in the dataloader.py code, but when I tried debugging it, it seems that class is for the NYU dataset, right?
If you could explain how to do it properly, that would be very helpful.

Thanks!

@fangchangma
Collaborator

Hi. To test with your own images, simply read them in as a 4-dimensional PyTorch floating-point tensor of size 1x224x224x3. The raw RGB values in [0, 255] should be divided by a constant factor of 255.0, so that all pixel values fall in the range [0, 1].
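
A minimal sketch of this preprocessing, assuming the image is loaded with matplotlib and is already 224x224 (the file name is hypothetical):

    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    # Read the image and scale raw RGB values from [0, 255] to [0, 1]
    img = plt.imread("my_image.jpg").astype(np.float32) / 255.0
    # Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
    img = np.expand_dims(img, axis=0)
    tensor = torch.from_numpy(img).float()
    print(tensor.shape)  # torch.Size([1, 224, 224, 3])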

@GustavoCamargoRL
Author

GustavoCamargoRL commented Aug 19, 2019

I tried doing this but I got this error: "RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 244, 244, 3] to have 3 channels, but got 244 channels instead". My image is 244x244 and I'm giving the right format, as you can see here:

    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> img = plt.imread("img.jpg") / 255.
    >>> img.shape
    (244, 244, 3)
    >>> img = np.expand_dims(img, axis=0)
    >>> img.shape
    (1, 244, 244, 3)
    >>> i = torch.from_numpy(img)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    NameError: name 'torch' is not defined
    >>> import torch
    >>> i = torch.from_numpy(img)
    >>> i.shape
    torch.Size([1, 244, 244, 3])

So I don't know why this error is occurring if I'm giving the exact same format.

@dwofk
Owner

dwofk commented Aug 19, 2019

The PyTorch conv2d function assumes inputs to be in 'NCHW' format, meaning that the tensor you feed into the network should have shape [1, 3, 224, 224]. From your code snippet, you appear to be using 'NHWC' format -- try permuting the tensor dimensions to change to 'NCHW'.
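
For example, a minimal sketch of the permutation (using a random tensor as a stand-in for the actual image):

    import torch

    nhwc = torch.rand(1, 224, 224, 3)             # tensor in NHWC layout
    nchw = nhwc.permute(0, 3, 1, 2).contiguous()  # reorder axes to NCHW
    print(nchw.shape)                             # torch.Size([1, 3, 224, 224])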

@fangchangma
Collaborator

> My image is 244x244 and I'm giving the right format

Also, the correct image size is 224 x 224, not 244 x 244
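
A sketch of one way to resize to the expected input size, assuming Pillow is available (the file name is hypothetical):

    import numpy as np
    from PIL import Image

    # PIL's resize takes (width, height)
    img = Image.open("my_image.jpg").resize((224, 224))
    img = np.asarray(img, dtype=np.float32) / 255.0  # (224, 224, 3) in [0, 1]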

@GustavoCamargoRL
Author

Oh thanks, it worked! Just one more problem: these were the results:
[attached images: Figure_1, img]

I placed my input code in the "args.evaluate" if condition and then saved my results in a PLY file. My question is whether there is some post-processing needed for a correct depth map prediction that I forgot to do, or whether it just didn't work for this image.

@fangchangma
Collaborator

Have you divided the input RGB values by 255.0, as in this line?

    rgb_np = np.asfarray(rgb_np, dtype='float') / 255

@GustavoCamargoRL
Author

GustavoCamargoRL commented Aug 19, 2019

Not exactly like this.
This is my input code :

    img = plt.imread("img.jpg")/255.  #normalization
    img = np.reshape(img, (3, 224, 224))
    img = np.expand_dims(img, axis=0)
    print(img.shape)
    with torch.no_grad():
        pred = model(torch.from_numpy(img).float().cuda())
        np.save('pred.npy', pred.cpu())
    
    print(pred)
    import sys
    sys.exit(0)

@fangchangma
Collaborator

> img = np.reshape(img, (3, 224, 224))

I believe it should be a permutation of dimensions here, rather than a reshape (which breaks the data ordering). Please try img = np.transpose(img, (2, 0, 1)) and see if it makes a difference.
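
A small sketch illustrating the difference: reshape keeps the flat memory order and merely reinterprets it, while transpose actually moves the axes:

    import numpy as np

    a = np.arange(12).reshape(2, 3, 2)  # stand-in for an HxWxC image, H=2, W=3, C=2

    bad = np.reshape(a, (2, 2, 3))      # reinterprets the buffer; pixels get scrambled
    good = np.transpose(a, (2, 0, 1))   # moves the channel axis to the front: CxHxW

    print(bad[0, 0])   # [0 1 2] -- mixes values from different pixels/channels
    print(good[0, 0])  # [0 2 4] -- channel 0 across the first row, as intended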

@GustavoCamargoRL
Author

It worked much better!
I will try for better results with other images. Thanks for the help!
[attached image: Figure_2]

@mathmax12

mathmax12 commented Oct 29, 2020

Thanks for the work @dwofk @fangchangma.
I am trying the same thing as @GustavoCamargoRL did.

    while True:
        image_cuda = torch.from_numpy(img).float().cuda()
        pred = 0
        print(pred)
        with torch.no_grad():
            pred = model(image_cuda)
            # np.save('pred.npy', pred.cpu())
        print(pred)

The output from the first iteration looks good, but at each iteration the output differs from the output of the other iterations, even with the same input image (see the screenshot below).
If I kill the thread and re-run the code, the first iteration always gives the same output.
[attached screenshot]

I printed the pred values and found that they do differ from the previous iteration, even with the same input image and the same model.
[attached screenshot]

Is there anything I missed when using the model?
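
One way to narrow this down is to check whether the PyTorch model itself is deterministic (a sketch, assuming model and img as in the snippet above):

    import torch

    model.eval()  # disable dropout/batchnorm updates, a common source of varying outputs
    with torch.no_grad():
        x = torch.from_numpy(img).float().cuda()
        out1 = model(x)
        out2 = model(x)
    # True -> the variation likely comes from outside the model (e.g., the deployment runtime)
    print(torch.allclose(out1, out2))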

@mathmax12

@GustavoCamargoRL Do you have the same issue?

@LulaSan

LulaSan commented May 11, 2021

@mathmax12 Have you done this using Apache TVM?

@mathmax12

@LulaSan It turns out that this was caused by TVM; the latest TVM version solved it.

@LulaSan

LulaSan commented May 13, 2021

@mathmax12 OK, thank you. Can I ask how you visualize the results? Using their visualize.py code?

@mathmax12

You can save the results as in https://github.com/dwofk/fast-depth/blob/master/main.py#L98, or use cv2.imshow() to display them.
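
A minimal sketch of the cv2 route, assuming pred is a 1x1xHxW tensor on the GPU as in the snippets above:

    import cv2
    import numpy as np

    depth = pred.squeeze().cpu().numpy()                        # (H, W) depth map
    norm = (depth - depth.min()) / (depth.max() - depth.min())  # scale to [0, 1] for display
    vis = cv2.applyColorMap((norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    cv2.imshow("depth", vis)
    cv2.waitKey(0)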
