reproducibility issue of DGL #3302

Closed
xhran2010 opened this issue Aug 30, 2021 · 31 comments
Labels: topic: PyTorch (Issues only in PyTorch backend)

@xhran2010 commented Aug 30, 2021

🐛 Bug

I used DGL to build a GAT-like network, and I fixed the seeds for Python, NumPy, PyTorch, and DGL for reproducibility. However, the results are still not deterministic, and the range of variation is very large. Specifically, I used the following code to fix the seeds:

import random
import numpy as np
import torch
import dgl

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    dgl.seed(seed)

To Reproduce

My GAT-like network looks like this:

import torch
import torch.nn as nn

class GATLayer(nn.Module):
    def __init__(self, hidden_size, alpha, beta, gamma=0.2, dropout=0.6):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.beta = beta

        self.hidden_size = hidden_size

        self.W_fc = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.attn_fc = nn.Linear(2 * hidden_size, 1, bias=False)
        self.leakyrelu = nn.LeakyReLU(self.gamma)
    
    def edge_attention(self, edges):
        z2 = torch.cat([edges.src['emb_attn'], edges.dst['emb_attn']], dim=1) # N x 2h
        a = self.attn_fc(z2) # N x 1
        return {'e': self.leakyrelu(a)} # N x 1

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['emb_crf'], 'e': edges.data['e']}
    
    def reduce_func(self, nodes):
        alpha = torch.softmax(nodes.mailbox['e'], dim=1) # N x 1
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1) # N x h -> 1 x h
        return {'h': h}

    def forward(self, embedding_input, h_input, graph):
        dv = 'cuda' if embedding_input.is_cuda else 'cpu'

        z = self.W_fc(h_input)
        graph.ndata['emb_crf'] = h_input
        graph.ndata['emb_attn'] = z
        graph.apply_edges(self.edge_attention)
        graph.update_all(self.message_func, self.reduce_func)
        
        gat_output = graph.ndata.pop('h')
        output = (self.alpha * embedding_input + self.beta * gat_output) / (self.alpha + self.beta)

        return output

Expected behavior

With all seeds fixed, the results should be identical across runs.

Environment

  • DGL Version (e.g., 1.0): 0.6.x
  • Backend Library & Version (e.g., PyTorch 0.4.1, MXNet/Gluon 1.3): PyTorch 1.9.0
  • OS (e.g., Linux): Linux
  • How you installed DGL (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.7.9
  • CUDA/cuDNN version (if applicable): 10.2
  • GPU models and configuration (e.g. V100): P40
  • Any other relevant information:

Additional context

@VoVAllen (Collaborator)

Can you provide the full code for us to reproduce the issue? Also, could you try turning off dropout to see how large the variance is?

@xhran2010 (Author)

Due to data privacy issues, I can't provide the full code, sorry about that. Also, I didn't turn on dropout.

@xhran2010 (Author)

Are there any possible reasons?

@VoVAllen (Collaborator)

@xhran2010 The computations in your code with UDFs are all torch operations; there's no direct DGL computation here. DGL only buckets the edge operations together, which should be deterministic. It's unclear to me what the possible reason is for now. Are you running it on CPU or GPU? Also, could you try DGL 0.7?

@xhran2010 (Author)

@VoVAllen My task is not a typical node-level/graph-level task. Instead, in each batch I sampled a subgraph using dgl.node_subgraph() and fed the nodes in the subgraph into the graph network. Then I used these updated embeddings of the nodes in the subgraph for downstream tasks. Could this introduce any uncertainty?

@xhran2010 (Author)

I run my code on a GPU, and I tried DGL 0.7, but it still didn't work.

@VoVAllen (Collaborator)

Sampling is not deterministic in DGL unless you turn off OpenMP by setting OMP_NUM_THREADS=1 as an environment variable.
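
A minimal sketch of that suggestion (assuming the variable is read when DGL/OpenMP initializes, so it must be set before the process starts or before dgl is imported):

# Option 1: in the shell, before launching the script
#   export OMP_NUM_THREADS=1

# Option 2: at the very top of the Python script, before importing dgl
import os
os.environ["OMP_NUM_THREADS"] = "1"

import dgl  # imported only after the variable is in the environment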

@xhran2010 (Author)

I sampled the node IDs with NumPy and provided dgl.node_subgraph() with deterministic node indices.

@xhran2010 (Author) commented Aug 31, 2021

@VoVAllen Specifically, the pseudocode looks like this:

def forward(self):
    neg_items = self.neg_sampling()  # sampling by np.random.choice()
    item_involve = torch.cat([pos_items, neg_items])
    sub_adj = self.adj.subgraph(item_involve)

    node_repr = self.embedding(item_involve)
    node_repr = self.graph_model(node_repr, sub_adj)  # implemented with DGL
    loss = loss_func(node_repr)
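
As an aside, a hedged sketch of how the NumPy side of this pseudocode can be made deterministic (the seeded generator and the tensor/shape choices are illustrative assumptions, not the author's code):

import numpy as np
import torch
import dgl

rng = np.random.default_rng(42)  # seeded generator: same draws on every run

def sample_subgraph(graph, pos_items, num_items, num_neg):
    # Deterministic negative sampling thanks to the fixed seed above.
    neg_items = torch.from_numpy(rng.choice(num_items, size=num_neg, replace=False))
    item_involve = torch.cat([pos_items, neg_items])
    # dgl.node_subgraph only relabels the given nodes, so with fixed
    # indices the induced subgraph is identical between runs.
    return dgl.node_subgraph(graph, item_involve), item_involve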

VoVAllen self-assigned this Aug 31, 2021
jermainewang added this to Issue in triage in DGL Tracker via automation Sep 6, 2021
jermainewang added the bug:unconfirmed (may be a bug, needs further investigation) label Sep 6, 2021
@VoVAllen (Collaborator) commented Sep 6, 2021

It's hard to tell which part results in the non-determinism. Could you provide a minimal reproducible example?

@nv-dlasalle (Collaborator)

@VoVAllen Why is sampling non-deterministic with the same number of threads? Don't we have a separate RNG per thread?

@BarclayII (Collaborator) commented Sep 24, 2021

@VoVAllen Why is sampling non-deterministic with the same number of threads? Don't we have a separate RNG per thread?

Even if the threads have separate RNGs, the thread with the same RNG can still be assigned to different nodes between runs (unless you control the scheduling somehow), so in the end you may still have different neighbors sampled for the same node between runs.

I don't think this is related to the OP though.

@avikpal00

I am facing a similar issue using the GATConv module from dgl.nn, where the results I'm getting across training experiments vary over a wide range (approx. 0.40 - 0.83 precision). Has there been any update on this issue that could make this process more deterministic?

A sample reference for what I'm trying out:

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GATConv

class GATCN(nn.Module):
    def __init__(self, in_feats, h_feats, n_heads, num_classes):
        super(GATCN, self).__init__()
        self.gat1 = GATConv(in_feats, h_feats, num_heads=n_heads)
        self.gat2 = GATConv(h_feats*n_heads, num_classes, num_heads=1)
    
    def forward(self, g, in_feat):
        # Apply graph convolution and activation.
        bs = in_feat.shape[0]
        h = F.relu(self.gat1(g, in_feat))
        h = h.reshape(bs, -1)
        h = F.relu(self.gat2(g, h))
        h = h.reshape(bs, -1)
        return h

def train(g, model, epochs=100, learning_rate=0.01, get_intermediate_embeddings=False):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    all_logits = []
    all_hidden_embeds = []
    for e in range(epochs):
        # Forward
        logits = model(g, features)
        
        # we save the logits for visualization later
        all_logits.append(logits.detach())

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))
    
    return all_logits, all_hidden_embeds


model = GATCN(g.ndata['feat'].shape[1], 16, 5, dataset.num_classes)
trained_embeddings, hidden_layer_embeds = train(g, model, epochs=100, learning_rate=0.01, get_intermediate_embeddings=False)
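
A hedged sketch of how the seeding discussed earlier in this thread could be applied to the snippet above, so the initial weights and training trajectory match between runs (the seed value and the placement are assumptions):

import random
import numpy as np
import torch
import dgl

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    dgl.seed(seed)

set_seeds(0)                              # call this before GATCN(...) and train(...)
torch.use_deterministic_algorithms(True)  # surfaces nondeterministic torch ops; on CUDA >= 10.2
                                          # this may also need CUBLAS_WORKSPACE_CONFIG (discussed below)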

@github-actions

This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you

VoVAllen added the topic: PyTorch (Issues only in PyTorch backend) label and removed the bug:unconfirmed and stale-issue labels Feb 17, 2022
@duncanriach commented Mar 3, 2022

@xhran2010, instead of setting torch.backends.cudnn.deterministic = True, you should be calling torch.use_deterministic_algorithms(True). If the source of nondeterminism is a PyTorch op (other than the cuDNN ops), then this will either make it function deterministically or throw an exception, so you'll know which op is getting in the way of reproducibility.

You might also want to take a look at the general instructions for reproducibility in PyTorch. There is also some useful info for PyTorch reproducibility here, which I maintain.
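
A minimal, self-contained sketch of that switch (the tiny linear layer is only a stand-in to show what the exception looks like; replace it with your own model):

import torch
import torch.nn as nn

torch.use_deterministic_algorithms(True)

# Ops without a deterministic implementation now raise a RuntimeError that
# names the offending op, instead of silently varying between runs.
layer = nn.Linear(8, 4)
if torch.cuda.is_available():
    layer = layer.cuda()
x = torch.randn(2, 8, device=next(layer.parameters()).device)
try:
    layer(x).sum().backward()
except RuntimeError as err:
    print("Nondeterministic op found:", err)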

@decoherencer (Contributor)

I tried the dataloader for batching with a MultiLayerFullNeighborSampler, and the variance is still present even after I removed dropout and set seeds. I also tried with no shuffling of the training data and a single node per batch, but I still can't reproduce the results.

With torch.use_deterministic_algorithms, it pointed to the use of a linear layer on the features, F.linear(input, self.weight, self.bias):

RuntimeError: Deterministic behavior was enabled with either torch.use_deterministic_algorithms(True) or at::Context::setDeterministicAlgorithms(true), but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behaviour, in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility

I set it to 4096:8, but the results still vary. I used the minibatching code from here and the GAT from here for node classification on Cora.
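
For reference, a sketch of how that variable is usually applied; per the error message it has to be in the environment before the PyTorch application starts using CUDA, so setting it in the launching shell (or at the very top of the script) is the safer assumption:

# In the shell:
#   export CUBLAS_WORKSPACE_CONFIG=:4096:8

# Or at the very top of the script, before any CUDA work:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch
torch.use_deterministic_algorithms(True)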

@jermainewang (Member)

Could you provide your script so we could take a look? @duncanriach

@duncanriach commented Mar 7, 2022

@decoherencer, wonderful. So, in your particular program, you have ruled out PyTorch ops as the only source of the nondeterminism you're seeing, and you have now also eliminated any nondeterminism that was previously due to PyTorch ops. Have you followed @VoVAllen's instructions for making DGL sampling deterministic?

Note to self: the creator of this issue, @xhran2010, is not using DGL's sampling, so we're still waiting for @xhran2010 to report on what happens when torch.use_deterministic_algorithms(True) is used with his program.

@jermainewang, you have asked for me to provide a script. I presume that you meant to ask for a script from @decoherencer.

@rickyxume

I'm meeting the same problem with GraphSAGE sampling on DGL 0.8.0post1 and PyTorch 1.11.0.

I have tried the following methods:

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
dgl.seed(args.seed)
if args.device != 'cpu':
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # multi-GPU
torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

The pure PyTorch environment is deterministic, but it isn't once DGL is involved.

So, excluding PyTorch as the cause, I think the source of the non-reproducibility may be that DGL's sampler or dataloader is nondeterministic.

@duncanriach

@rickyxume, you have ensured that PyTorch is operating deterministically. Have you followed @VoVAllen's instructions for making DGL sampling deterministic?

@rickyxume

@duncanriach, I also tried more ways to ensure determinism, including @VoVAllen's instructions, but they all failed.

os.environ['OMP_NUM_THREADS'] = '1'
# export OMP_NUM_THREADS=1
os.environ['MKL_NUM_THREADS'] = '1'
torch.set_num_threads(1)

Running next(iter(train_dataloader)) twice separately will give different results.

@duncanriach commented Mar 25, 2022

@rickyxume, I have many points to make:

  • First of all, the way you're getting an iterator from train_dataloader raises a red flag for me. An iterator usually has some state inside it that advances with each iteration. Usually, when we get an iterator it's because we want something that generates something different on each iteration. A reproducible "reset and train from scratch" function should simply be runnable and, if deterministic, should produce exactly the same result on each run. One could build an iterator that doesn't iterate, but it would add pointless additional complexity. The fact that you're getting different results on each iteration of an iterator is pretty much confirmation that you do in fact have an iterator.
  • Regarding "but all failed": this is not necessarily true. There may be multiple sources of nondeterminism in your program. If there is more than one source of nondeterminism in your program, then eliminating one source will not yield determinism, but it will take you one step closer to determinism. I know what you mean though, flipping those switches didn't seem to fix the problem.
  • The fact that you followed up by trying even more PyTorch-related settings suggests to me that you did not, in fact, first confirm that Python-related determinism had been achieved. What you're doing looks like a flip-all-the-switches approach rather than a systematic debugging approach. It's best to try to break the problem down to isolate any remaining sources of nondeterminism. If possible, I recommend breaking down to a level where you have determinism and then building up until determinism is lost; that way you can isolate and (hopefully) eliminate the source(s) of nondeterminism. If a source of nondeterminism cannot be eliminated, then a specific issue can be created in the appropriate open-source repository (e.g. PyTorch, DGL, Horovod, or Petastorm).
  • From your original comment, it looks like you're running on multiple devices (multiple GPUs in this case). Going to multi-GPU is the final step you should take after getting your full program running deterministically on a single GPU.
  • You also don't need to run your whole training cycle to detect nondeterminism. You can just run for a couple of steps and print some kind of summary (a sum or a hash) of your trainable variables; a minimal sketch of such a checksum follows this comment. You should then be able to see immediately, in a few seconds, whether there is nondeterminism. Once you have determinism for ten steps or so, the full training process will usually prove to be deterministic.
  • Now, when running on a single GPU, for only a few steps (and being able to compare the summary of trainable variables after each run), if you're still getting nondeterminism, then you can work on partitioning the problem by potentially eliminating possible sources from DGL vs PyTorch. For example, you might stub-out the DGL functionality, such as sampling, with something simple and definitely deterministic (e.g. a fixed sample pattern, but one that will not mask any possible downstream nondeterminism). By doing this, you should be able to confirm (or ensure) that the underlying PyTorch functionality is deterministic. I'm not deeply familiar with DGL yet, however, so my guidance here must be interpreted based on what partitioning is easily possible.
  • If your PyTorch-centric program is nondeterministic after a few training steps even though you've followed the guidelines to achieve PyTorch determinism, then you could partition the problem again by seeing if your trainable variables are reproducible before training begins. Another thing to confirm is that the data being fed to your model is reproducible (i.e. the summary of the data is the same on each step on both runs).
  • Once you have single GPU, PyTorch-centric determinism for a few steps, then you can add back in the DGL functionality piece-by-piece until you find what is introducing the nondeterminism. If it's some DGL functionality, then, at that point you can provide a very minimal program (in a DGL issue) that demonstrates the problem, a program that demonstrates the specific injection of nondeterminism from code in the DGL repository.
  • You might find that your whole program runs deterministically on a single GPU, however, and that the nondeterminism is getting introduced by your multi-GPU configuration.

Note to self (with others watching): there are many people commenting on this very generically-titled issue ("reproducibility issue of DGL") all of whom may, in fact, be dealing with completely different issues from each other. It's similar to an issue with the title "DGL program runs slow"; anyone who is unhappy with the speed of their program could then comment on the issue. There are many possible reasons for any given DGL program to run slow, many of them unrelated to DGL. There are many possible sources of nondeterminism in any given DGL program, many of them unrelated to DGL. Ideally, what we want is a specific little reproducer program with minimal functionality (no more complex than it needs to be) that demonstrates specifically what is not doing what it should, i.e. what in DGL (if anything) is running slow or what in DGL (if anything) is introducing nondeterminism.
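
A minimal sketch of the trainable-variable checksum mentioned in the list above (the helper name and the stand-in model are assumptions):

import torch
import torch.nn as nn

def param_checksum(model):
    # Cheap run-to-run fingerprint of the trainable parameters; two
    # deterministic runs must print exactly the same value.
    with torch.no_grad():
        return sum(p.double().sum().item() for p in model.parameters())

torch.manual_seed(0)
tiny = nn.Linear(4, 2)  # stand-in for the real model
print(f"checksum: {param_checksum(tiny):.12f}")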

@rickyxume

@duncanriach Great comment! I get it. I've currently achieved determinism without using DGL's dataloader, running on a single GPU.

My problem may be the one mentioned in @BarclayII's comment; I currently still suspect that dgl.dataloading.NeighborSampler introduces nondeterminism. If there are still problems, I'll try to provide a minimal reproducible program.

Thanks!

@BarclayII (Collaborator)

@duncanriach Great comment! I get it. I've currently achieved determinism without using DGL's dataloader, running on a single GPU.

My problem may be the one mentioned in @BarclayII's comment; I currently still suspect that dgl.dataloading.NeighborSampler introduces nondeterminism. If there are still problems, I'll try to provide a minimal reproducible program.

Thanks!

If you absolutely want to remove the non-determinism in neighbor sampling, you could try setting num_workers=1 (which disables OpenMP in neighbor sampling since the sampling happens in subprocesses, but only in DGL 0.8+), or setting the environment variable OMP_NUM_THREADS=1.
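
A sketch of where that argument goes, assuming the DGL 0.8 dataloading API (the random graph, node IDs, and fan-outs are placeholders):

import torch
import dgl

g = dgl.rand_graph(100, 500)      # placeholder graph
train_nids = torch.arange(100)    # placeholder training node IDs

sampler = dgl.dataloading.NeighborSampler([10, 10])   # 2-layer fan-out
dataloader = dgl.dataloading.DataLoader(
    g, train_nids, sampler,
    batch_size=32,
    shuffle=False,    # or keep shuffling but fix the seeds as discussed above
    num_workers=1,    # sampling runs in a single worker subprocess (DGL 0.8+)
)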

@rickyxume

@BarclayII Thx! It's done!!! Setting num_workers=1 works!

OMP_NUM_THREADS=1 does not seem to work. Anyway, my problem was finally solved and I learned a lot from you guys!

@BarclayII (Collaborator)

Please feel free to reopen if the above still did not work.

DGL Tracker automation moved this from Issue in triage to Done Mar 28, 2022
@decoherencer (Contributor)

For those working in Jupyter: set the seeds in the training cell itself if you want to reproduce results across multiple runs.
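
For instance (a sketch that reuses the set_seeds helper and the GATCN/train snippet from earlier in this thread, assumed to be defined in earlier cells):

# Put the seeding at the top of the training cell itself, so re-running
# only this cell (not the whole notebook) restarts from the same RNG state.
set_seeds(0)
model = GATCN(g.ndata['feat'].shape[1], 16, 5, dataset.num_classes)
all_logits, _ = train(g, model, epochs=100, learning_rate=0.01)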

@xxhu94 commented Apr 1, 2022

This randomness is not observed when the dataset is small (around 6k nodes). But when the dataset is large (around 100,000 nodes), some randomness appears around the 4th decimal place, which makes the experiment not strictly reproducible.

@rickyxume

This randomness is not observed when the dataset is small (around 6k nodes). But when the dataset is large (around 100,000 nodes), some randomness appears around the 4th decimal place, which makes the experiment not strictly reproducible.

I encountered a similar problem: when all the reproducible settings discussed earlier were used, the results were deterministic when the number of epochs was small (e.g. 10 epochs), but became uncertain when the number of epochs was much larger (e.g. 100 epochs).

Also, I found that the GPU introduced more uncertainty than the CPU, because the CPU's results are often more deterministic compared to the GPU's for the same number of epochs.

sys.platform --> linux
Python --> 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0]
Numpy --> 1.21.2
Scikit-learn --> 1.0.2
PyTorch --> 1.11.0
DGL --> 0.8.0post1
CUDA available --> True
Device --> [('GPU 0', 'Tesla V100S-PCIE-32GB')]

@yhong4 commented Nov 21, 2022

@BarclayII @rickyxume Hello, I just encountered the same issue with SAGEConv. Just wondering how you solved it? I'm pretty new to DGL. Where did you set num_workers=1?

Many thanks!

@jermainewang (Member)

@yhong4 You may want to open a new issue for more visibility.
