22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
5+ using System . Collections . Generic ;
6+ using System . Linq ;
57using System . Threading . Tasks ;
68using Xunit ;
79
@@ -68,6 +70,55 @@ public static async Task CaptureAndRestoreEmptyContext()
6870 Assert . Equal ( local . Value , 12 ) ;
6971 }
7072
73+ [ Theory ]
74+ [ MemberData ( nameof ( GetCounts ) ) ]
75+ public static async Task CaptureAndRestoreNullAsyncLocals ( int count )
76+ {
77+ AsyncLocal < object > [ ] locals = new AsyncLocal < object > [ count ] ;
78+ for ( var i = 0 ; i < locals . Length ; i ++ )
79+ {
80+ locals [ i ] = new AsyncLocal < object > ( ) ;
81+ }
82+
83+ ExecutionContext ec = ExecutionContext . Capture ( ) ;
84+
85+ ExecutionContext . Run (
86+ ec ,
87+ _ =>
88+ {
89+ for ( var i = 0 ; i < locals . Length ; i ++ )
90+ {
91+ AsyncLocal < object > local = locals [ i ] ;
92+
93+ Assert . Null ( local . Value ) ;
94+ local . Value = 56 ;
95+ Assert . IsType < int > ( local . Value ) ;
96+ Assert . Equal ( 56 , ( int ) local . Value ) ;
97+ }
98+ } ,
99+ null ) ;
100+
101+ for ( var i = 0 ; i < locals . Length ; i ++ )
102+ {
103+ Assert . Null ( locals [ i ] . Value ) ;
104+ }
105+ }
106+
107+ [ Fact ]
108+ public static async Task CaptureAndRunOnFlowSupressedContext ( )
109+ {
110+ ExecutionContext . SuppressFlow ( ) ;
111+ try
112+ {
113+ ExecutionContext ec = ExecutionContext . Capture ( ) ;
114+ Assert . Throws < InvalidOperationException > ( ( ) => ExecutionContext . Run ( ec , _ => { } , null ) ) ;
115+ }
116+ finally
117+ {
118+ ExecutionContext . RestoreFlow ( ) ;
119+ }
120+ }
121+
71122 [ Fact ]
72123 public static async Task NotifyOnValuePropertyChange ( )
73124 {
@@ -365,10 +416,11 @@ await Run(async () =>
365416 Assert . Equal ( local . Value , 42 ) ;
366417 }
367418
368- [ Fact ]
369- public static async Task AddAndUpdateManyLocals_ValueType ( )
419+ [ Theory ]
420+ [ MemberData ( nameof ( GetCounts ) ) ]
421+ public static async Task AddAndUpdateManyLocals_ValueType ( int count )
370422 {
371- var locals = new AsyncLocal < int > [ 40 ] ;
423+ var locals = new AsyncLocal < int > [ count ] ;
372424 for ( int i = 0 ; i < locals . Length ; i ++ )
373425 {
374426 locals [ i ] = new AsyncLocal < int > ( ) ;
@@ -387,10 +439,11 @@ public static async Task AddAndUpdateManyLocals_ValueType()
387439 }
388440 }
389441
390- [ Fact ]
391- public static async Task AddUpdateAndRemoveManyLocals_ReferenceType ( )
442+ [ Theory ]
443+ [ MemberData ( nameof ( GetCounts ) ) ]
444+ public static async Task AddUpdateAndRemoveManyLocals_ReferenceType ( int count )
392445 {
393- var locals = new AsyncLocal < string > [ 40 ] ;
446+ var locals = new AsyncLocal < string > [ count ] ;
394447
395448 for ( int i = 0 ; i < locals . Length ; i ++ )
396449 {
@@ -419,5 +472,180 @@ public static async Task AddUpdateAndRemoveManyLocals_ReferenceType()
419472 }
420473 }
421474 }
475+
476+ [ Theory ]
477+ [ MemberData ( nameof ( GetCounts ) ) ]
478+ public static async Task AsyncLocalsUnwind ( int count )
479+ {
480+ AsyncLocal < object > [ ] asyncLocals = new AsyncLocal < object > [ count ] ;
481+
482+ ExecutionContext Default = ExecutionContext . Capture ( ) ;
483+ int [ ] manuallySetCounts = new int [ count ] ;
484+ int [ ] automaticallyUnsetCounts = new int [ count ] ;
485+ int [ ] automaticallySetCounts = new int [ count ] ;
486+ ExecutionContext [ ] capturedContexts = new ExecutionContext [ count ] ;
487+
488+ // Setup the AsyncLocals; capturing ExecutionContext for each level
489+ await SetLocalsRecursivelyAsync ( count - 1 ) ;
490+
491+ ValidateCounts ( thresholdIndex : 0 , maunalSets : 1 , automaticUnsets : 1 , automaticSets : 0 ) ;
492+ ValidateAsyncLocalsValuesNull ( ) ;
493+
494+ // Check Running with the contexts captured when setting the locals
495+ TestCapturedExecutionContexts ( ) ;
496+
497+ ExecutionContext . SuppressFlow ( ) ;
498+ try
499+ {
500+ // Re-check restoring, but starting with a suppressed flow
501+ TestCapturedExecutionContexts ( ) ;
502+ }
503+ finally
504+ {
505+ ExecutionContext . RestoreFlow ( ) ;
506+ }
507+
508+ // -- Local functions --
509+ void ValidateAsyncLocalsValuesNull ( )
510+ {
511+ // Check AsyncLocals haven't leaked
512+ for ( int i = 0 ; i < asyncLocals . Length ; i ++ )
513+ {
514+ Assert . Null ( asyncLocals [ i ] . Value ) ;
515+ }
516+ }
517+
518+ void ValidateAsyncLocalsValues ( int thresholdIndex )
519+ {
520+ for ( int localsIndex = 0 ; localsIndex < asyncLocals . Length ; localsIndex ++ )
521+ {
522+ if ( localsIndex >= thresholdIndex )
523+ {
524+ Assert . Equal ( localsIndex , ( int ) asyncLocals [ localsIndex ] . Value ) ;
525+ }
526+ else
527+ {
528+ Assert . Null ( asyncLocals [ localsIndex ] . Value ) ;
529+ }
530+ }
531+ }
532+
533+ void TestCapturedExecutionContexts ( )
534+ {
535+ for ( int contextIndex = 0 ; contextIndex < asyncLocals . Length ; contextIndex ++ )
536+ {
537+ ClearCounts ( ) ;
538+
539+ ExecutionContext . Run (
540+ capturedContexts [ contextIndex ] . CreateCopy ( ) ,
541+ ( o ) => TestCapturedExecutionContext ( ( int ) o ) ,
542+ contextIndex ) ;
543+
544+ // Validate locals have been restored to the Default context's values
545+ ValidateAsyncLocalsValuesNull ( ) ;
546+ }
547+ }
548+
549+ void TestCapturedExecutionContext ( int contextIndex )
550+ {
551+ ValidateCounts ( thresholdIndex : contextIndex , maunalSets : 0 , automaticUnsets : 0 , automaticSets : 1 ) ;
552+ // Validate locals have been restored to the outer context's values
553+ ValidateAsyncLocalsValues ( thresholdIndex : contextIndex ) ;
554+
555+ // Validate locals are correctly reset Running with a Default context from a non-Default context
556+ ExecutionContext . Run (
557+ Default . CreateCopy ( ) ,
558+ _ => ValidateAsyncLocalsValuesNull ( ) ,
559+ null ) ;
560+
561+ ValidateCounts ( thresholdIndex : contextIndex , maunalSets : 0 , automaticUnsets : 1 , automaticSets : 2 ) ;
562+ // Validate locals have been restored to the outer context's values
563+ ValidateAsyncLocalsValues ( thresholdIndex : contextIndex ) ;
564+
565+ for ( int innerContextIndex = 0 ; innerContextIndex < asyncLocals . Length ; innerContextIndex ++ )
566+ {
567+ // Validate locals are correctly restored Running with another non-Default context from a non-Default context
568+ ExecutionContext . Run (
569+ capturedContexts [ innerContextIndex ] . CreateCopy ( ) ,
570+ o => ValidateAsyncLocalsValues ( thresholdIndex : ( int ) o ) ,
571+ innerContextIndex ) ;
572+
573+ // Validate locals have been restored to the outer context's values
574+ ValidateAsyncLocalsValues ( thresholdIndex : contextIndex ) ;
575+ }
576+ }
577+
578+ void ValidateCounts ( int thresholdIndex , int maunalSets , int automaticUnsets , int automaticSets )
579+ {
580+ for ( int localsIndex = 0 ; localsIndex < asyncLocals . Length ; localsIndex ++ )
581+ {
582+ Assert . Equal ( localsIndex < thresholdIndex ? 0 : maunalSets , manuallySetCounts [ localsIndex ] ) ;
583+ Assert . Equal ( localsIndex < thresholdIndex ? 0 : automaticUnsets , automaticallyUnsetCounts [ localsIndex ] ) ;
584+ Assert . Equal ( localsIndex < thresholdIndex ? 0 : automaticSets , automaticallySetCounts [ localsIndex ] ) ;
585+ }
586+ }
587+
588+ // Synchronous function is async to create different ExectutionContexts for each set, and check async unwinding
589+ async Task SetLocalsRecursivelyAsync ( int index )
590+ {
591+ // Set AsyncLocal
592+ asyncLocals [ index ] = new AsyncLocal < object > ( CountValueChanges )
593+ {
594+ Value = index
595+ } ;
596+
597+ // Capture context with AsyncLocal set
598+ capturedContexts [ index ] = ExecutionContext . Capture ( ) ;
599+
600+ if ( index > 0 )
601+ {
602+ // Go deeper into async stack
603+ int nextIndex = index - 1 ;
604+ await SetLocalsRecursivelyAsync ( index - 1 ) ;
605+ // Set is undone by the await
606+ Assert . Null ( asyncLocals [ nextIndex ] . Value ) ;
607+ }
608+ }
609+
610+ void CountValueChanges ( AsyncLocalValueChangedArgs < object > args )
611+ {
612+ if ( ! args . ThreadContextChanged )
613+ {
614+ // Manual create, previous should be null
615+ Assert . Null ( args . PreviousValue ) ;
616+ Assert . IsType < int > ( args . CurrentValue ) ;
617+ manuallySetCounts [ ( int ) args . CurrentValue ] ++ ;
618+ }
619+ else
620+ {
621+ // Automatic change, only one value should be not null
622+ if ( args . CurrentValue != null )
623+ {
624+ Assert . Null ( args . PreviousValue ) ;
625+ Assert . IsType < int > ( args . CurrentValue ) ;
626+ automaticallySetCounts [ ( int ) args . CurrentValue ] ++ ;
627+ }
628+ else
629+ {
630+ Assert . Null ( args . CurrentValue ) ;
631+ Assert . NotNull ( args . PreviousValue ) ;
632+ Assert . IsType < int > ( args . PreviousValue ) ;
633+ automaticallyUnsetCounts [ ( int ) args . PreviousValue ] ++ ;
634+ }
635+ }
636+ }
637+
638+ void ClearCounts ( )
639+ {
640+ Array . Clear ( manuallySetCounts , 0 , count ) ;
641+ Array . Clear ( automaticallyUnsetCounts , 0 , count ) ;
642+ Array . Clear ( automaticallySetCounts , 0 , count ) ;
643+ }
644+ }
645+
646+ // The data structure that holds AsyncLocals changes based on size;
647+ // so it needs to be tested at a variety of sizes
648+ public static IEnumerable < object [ ] > GetCounts ( )
649+ => Enumerable . Range ( 1 , 40 ) . Select ( i => new object [ ] { i } ) ;
422650 }
423651}
0 commit comments